diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -13,6 +13,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Index/IR/IndexDialect.td" include "mlir/Dialect/Index/IR/IndexEnums.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -23,7 +24,8 @@ /// Base class for Index dialect operations. class IndexOp traits = []> - : Op; + : Op] # traits>; //===----------------------------------------------------------------------===// // IndexBinaryOp diff --git a/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h @@ -0,0 +1,126 @@ +//===- InferIntRangeCommon.h - Inference for common ops ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H +#define MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H + +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace intrange { +/// Function that performs inference on an array of `ConstantIntRanges`, +/// abstracted away here to permit writing the function that handles both +/// 64- and 32-bit index types. +using InferRangeFn = + function_ref)>; + +static constexpr unsigned indexMinWidth = 32; +static constexpr unsigned indexMaxWidth = 64; + +enum class CmpMode : uint32_t { Both, Signed, Unsigned }; + +/// Compute `inferFn` on `argRanges`, whose width should be the index storage +/// bitwidth. Then, compute the function on `argRanges` again after truncating +/// the ranges to 32 bits. Finally, if the truncation of the 64-bit result is +/// equal to the 32-bit result, use it (to preserve compatibility with folders +/// and inference precision), and take the union of the results otherwise. +/// +/// The `mode` argument specifies if the unsigned, signed, or both results of +/// the inference computation should be used when comparing the results. +ConstantIntRanges inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + CmpMode mode); + +/// Independently zero-extend the unsigned values and sign-extend the signed +/// values in `range` to `destWidth` bits, returning the resulting range. +ConstantIntRanges extRange(const ConstantIntRanges &range, unsigned destWidth); + +/// Use the unsigned values in `range` to zero-extend it to `destWidth`.
+ConstantIntRanges extUIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Use the signed values in `range` to sign-extend it to `destWidth`. +ConstantIntRanges extSIRange(const ConstantIntRanges &range, + unsigned destWidth); + +/// Truncate `range` to `destWidth` bits, taking care to handle cases such as +/// the truncation of [255, 256] to i8 not being a uniform range. +ConstantIntRanges truncRange(const ConstantIntRanges &range, + unsigned destWidth); + +ConstantIntRanges inferAdd(ArrayRef argRanges); + +ConstantIntRanges inferSub(ArrayRef argRanges); + +ConstantIntRanges inferMul(ArrayRef argRanges); + +ConstantIntRanges inferDivS(ArrayRef argRanges); + +ConstantIntRanges inferDivU(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivS(ArrayRef argRanges); + +ConstantIntRanges inferCeilDivU(ArrayRef argRanges); + +ConstantIntRanges inferFloorDivS(ArrayRef argRanges); + +ConstantIntRanges inferRemS(ArrayRef argRanges); + +ConstantIntRanges inferRemU(ArrayRef argRanges); + +ConstantIntRanges inferMaxS(ArrayRef argRanges); + +ConstantIntRanges inferMaxU(ArrayRef argRanges); + +ConstantIntRanges inferMinS(ArrayRef argRanges); + +ConstantIntRanges inferMinU(ArrayRef argRanges); + +ConstantIntRanges inferAnd(ArrayRef argRanges); + +ConstantIntRanges inferOr(ArrayRef argRanges); + +ConstantIntRanges inferXor(ArrayRef argRanges); + +ConstantIntRanges inferShl(ArrayRef argRanges); + +ConstantIntRanges inferShrS(ArrayRef argRanges); + +ConstantIntRanges inferShrU(ArrayRef argRanges); + +/// Copy of the enum from `arith` and `index` to allow the common integer range +/// infrastructure to not depend on either dialect. +enum class CmpPredicate : uint64_t { + eq, + ne, + slt, + sle, + sgt, + sge, + ult, + ule, + ugt, + uge, +}; + +/// Returns a boolean value if `pred` is statically true or false for +/// any possible inputs falling within `lhs` and `rhs`, and std::nullopt if the +/// value of the predicate cannot be determined.
+Optional evaluatePred(CmpPredicate pred, const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs); + +} // namespace intrange +} // namespace mlir + +#endif // MLIR_INTERFACES_UTILS_INFERINTRANGECOMMON_H diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -16,6 +16,7 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRInferIntRangeCommon MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include @@ -16,48 +17,7 @@ using namespace mlir; using namespace mlir::arith; - -/// Function that evaluates the result of doing something on arithmetic -/// constants and returns std::nullopt on overflow. -using ConstArithFn = - function_ref(const APInt &, const APInt &)>; - -/// Return the maxmially wide signed or unsigned range for a given bitwidth. - -/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, -/// If either computation overflows, make the result unbounded. 
-static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, - const APInt &minRight, - const APInt &maxLeft, - const APInt &maxRight, bool isSigned) { - std::optional maybeMin = op(minLeft, minRight); - std::optional maybeMax = op(maxLeft, maxRight); - if (maybeMin && maybeMax) - return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned); - return ConstantIntRanges::maxRange(minLeft.getBitWidth()); -} - -/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, -/// ignoring unbounded values. Returns the maximal range if `op` overflows. -static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef lhs, - ArrayRef rhs, bool isSigned) { - unsigned width = lhs[0].getBitWidth(); - APInt min = - isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width); - APInt max = - isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width); - for (const APInt &left : lhs) { - for (const APInt &right : rhs) { - std::optional maybeThisResult = op(left, right); - if (!maybeThisResult) - return ConstantIntRanges::maxRange(width); - APInt result = std::move(*maybeThisResult); - min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min; - max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max; - } - } - return ConstantIntRanges::range(min, max, isSigned); -} +using namespace mlir::intrange; //===----------------------------------------------------------------------===// // ConstantOp @@ -78,25 +38,7 @@ void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn uadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.uadd_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - ConstArithFn sadd = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sadd_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = computeBoundsBy( - uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferAdd(argRanges)); } //===----------------------------------------------------------------------===// @@ -105,25 +47,7 @@ void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn usub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.usub_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn ssub = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.ssub_ov(b, overflowed); - return overflowed ? 
std::optional() : result; - }; - ConstantIntRanges urange = computeBoundsBy( - usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); - ConstantIntRanges srange = computeBoundsBy( - ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferSub(argRanges)); } //===----------------------------------------------------------------------===// @@ -132,96 +56,25 @@ void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn umul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.umul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - ConstArithFn smul = [](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.smul_ov(b, overflowed); - return overflowed ? std::optional() : result; - }; - - ConstantIntRanges urange = - minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, - /*isSigned=*/true); - - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferMul(argRanges)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// -/// Fix up division results (ex. 
for ceiling and floor), returning an APInt -/// if there has been no overflow -using DivisionFixupFn = function_ref( - const APInt &lhs, const APInt &rhs, const APInt &result)>; - -static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), - &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { - auto udiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - return fixup(a, b, a.udiv(b)); - }; - return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/false); - } - // Otherwise, it's possible we might divide by 0. - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivUIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivU(argRanges)); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); - - if (canDivide) { - auto sdiv = [&fixup](const APInt &a, - const APInt &b) -> std::optional { - bool overflowed = false; - APInt result = a.sdiv_ov(b, overflowed); - return overflowed ? 
std::optional() : fixup(a, b, result); - }; - return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, - /*isSigned=*/true); - } - return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); -} - void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivSIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); + setResultRange(getResult(), inferDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -230,20 +83,7 @@ void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); + setResultRange(getResult(), inferCeilDivU(argRanges)); } //===----------------------------------------------------------------------===// @@ -252,20 +92,7 @@ void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? 
std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); + setResultRange(getResult(), inferCeilDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -274,20 +101,7 @@ void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn floorDivSIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { - bool overflowed = false; - APInt corrected = - result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); + return setResultRange(getResult(), inferFloorDivS(argRanges)); } //===----------------------------------------------------------------------===// @@ -296,29 +110,7 @@ void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferRemU(argRanges)); } 
//===----------------------------------------------------------------------===// @@ -327,67 +119,16 @@ void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), - &rhsMax = rhs.smax(); - - unsigned width = rhsMax.getBitWidth(); - APInt smin = APInt::getSignedMinValue(width); - APInt smax = APInt::getSignedMaxValue(width); - // No bounds if zero could be a divisor. - bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative()); - if (canBound) { - APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs(); - bool canNegativeDividend = lhsMin.isNegative(); - bool canPositiveDividend = lhsMax.isStrictlyPositive(); - APInt zero = APInt::getZero(maxDivisor.getBitWidth()); - APInt maxPositiveResult = maxDivisor - 1; - APInt minNegativeResult = -maxPositiveResult; - smin = canNegativeDividend ? minNegativeResult : zero; - smax = canPositiveDividend ? maxPositiveResult : zero; - // Special case: sweeping out a contiguous range in N/[modulus]. - if (rhsMin == rhsMax) { - if ((lhsMax - lhsMin).ult(maxDivisor)) { - APInt minRem = lhsMin.srem(maxDivisor); - APInt maxRem = lhsMax.srem(maxDivisor); - if (minRem.sle(maxRem)) { - smin = minRem; - smax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferRemS(argRanges)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. This gives us a consertive approximation -/// for what values can be passed to bitwise operations. 
-static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - setResultRange(getResult(), - minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferAnd(argRanges)); } //===----------------------------------------------------------------------===// @@ -396,14 +137,7 @@ void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - setResultRange(getResult(), - minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), inferOr(argRanges)); } //===----------------------------------------------------------------------===// @@ -412,14 +146,7 @@ void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - setResultRange(getResult(), - minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + setResultRange(getResult(), 
inferXor(argRanges)); } //===----------------------------------------------------------------------===// @@ -428,11 +155,7 @@ void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMaxS(argRanges)); } //===----------------------------------------------------------------------===// @@ -441,11 +164,7 @@ void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMaxU(argRanges)); } //===----------------------------------------------------------------------===// @@ -454,11 +173,7 @@ void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); - const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + setResultRange(getResult(), inferMinS(argRanges)); } //===----------------------------------------------------------------------===// @@ -467,94 +182,40 @@ void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - const APInt &umin = lhs.umin().ult(rhs.umin()) ? 
lhs.umin() : rhs.umin(); - const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + setResultRange(getResult(), inferMinU(argRanges)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extUIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); - return ConstantIntRanges::fromUnsigned(umin, umax); -} - void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// -static ConstantIntRanges extSIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // TruncIOp 
//===----------------------------------------------------------------------===// -static ConstantIntRanges truncIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. - // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. - // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? 
APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + unsigned destWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// @@ -569,9 +230,9 @@ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destType)); + setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -588,9 +249,9 @@ unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destType)); + setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); + setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } @@ -599,51 +260,19 @@ // CmpIOp //===----------------------------------------------------------------------===// -bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: - return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: - return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); - case 
arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: - return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); - case arith::CmpIPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case arith::CmpIPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. - bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; -} - void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - arith::CmpIPredicate pred = getPredicate(); + arith::CmpIPredicate arithPred = getPredicate(); + intrange::CmpPredicate pred = static_cast(arithPred); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnesValue(1); - if (isStaticallyTrue(pred, lhs, rhs)) + + Optional truthValue = intrange::evaluatePred(pred, lhs, rhs); + if (truthValue.has_value() && *truthValue) min = max; - else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + else if (truthValue.has_value() && !(*truthValue)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); @@ -673,18 +302,7 @@ void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - ConstArithFn shl = [](const APInt &l, - 
const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.shl(r); - }; - ConstantIntRanges urange = - minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/false); - ConstantIntRanges srange = - minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, - /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + setResultRange(getResult(), inferShl(argRanges)); } //===----------------------------------------------------------------------===// @@ -693,15 +311,7 @@ void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn lshr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); - }; - setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, - {rhs.umin(), rhs.umax()}, - /*isSigned=*/false)); + setResultRange(getResult(), inferShrU(argRanges)); } //===----------------------------------------------------------------------===// @@ -710,14 +320,5 @@ void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? 
std::optional() : l.ashr(r); - }; - - setResultRange(getResult(), - minMaxBy(ashr, {lhs.smin(), lhs.smax()}, - {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); + setResultRange(getResult(), inferShrS(argRanges)); } diff --git a/mlir/lib/Dialect/Index/IR/CMakeLists.txt b/mlir/lib/Dialect/Index/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Index/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Index/IR/CMakeLists.txt @@ -2,6 +2,7 @@ IndexAttrs.cpp IndexDialect.cpp IndexOps.cpp + InferIntRangeInterfaceImpls.cpp DEPENDS MLIRIndexOpsIncGen @@ -10,6 +11,8 @@ MLIRDialect MLIRIR MLIRCastInterfaces + MLIRInferIntRangeCommon + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,252 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; +using namespace mlir::index; +using namespace mlir::intrange; + +//===----------------------------------------------------------------------===// +// Constants +//===----------------------------------------------------------------------===// + +void ConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + const APInt &value = getValue(); + setResultRange(getResult(), ConstantIntRanges::constant(value)); +} + +void BoolConstantOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + bool value = getValue(); + APInt asInt(/*numBits=*/1, value); + setResultRange(getResult(), ConstantIntRanges::constant(asInt)); +} + +//===----------------------------------------------------------------------===// +// Arithmetic operations. All of these operations will have their results +// inferred using both the 64-bit values and truncated 32-bit values of their +// inputs, with the results being the union of those inferences, except where +// the truncation of the 64-bit result is equal to the 32-bit result (in which +// case we take the 64-bit result). 
+//===----------------------------------------------------------------------===// + +void AddOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferAdd, argRanges, CmpMode::Both)); +} + +void SubOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferSub, argRanges, CmpMode::Both)); +} + +void MulOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferMul, argRanges, CmpMode::Both)); +} + +void DivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned)); +} + +void DivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferDivS, argRanges, CmpMode::Signed)); +} + +void CeilDivUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned)); +} + +void CeilDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed)); +} + +void FloorDivSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + return setResultRange( + getResult(), inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed)); +} + +void RemSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemS, argRanges, CmpMode::Signed)); +} + +void RemUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned)); +} + +void MaxSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + 
inferIndexOp(inferMaxS, argRanges, CmpMode::Signed)); +} + +void MaxUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned)); +} + +void MinSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinS, argRanges, CmpMode::Signed)); +} + +void MinUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned)); +} + +void ShlOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), inferIndexOp(inferShl, argRanges, CmpMode::Both)); +} + +void ShrSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrS, argRanges, CmpMode::Signed)); +} + +void ShrUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned)); +} + +void AndOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned)); +} + +void OrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferOr, argRanges, CmpMode::Unsigned)); +} + +void XOrOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + setResultRange(getResult(), + inferIndexOp(inferXor, argRanges, CmpMode::Unsigned)); +} + +//===----------------------------------------------------------------------===// +// Casts +//===----------------------------------------------------------------------===// + +static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range, + unsigned srcWidth, unsigned destWidth, + bool isSigned) { + if (srcWidth < destWidth) + return 
isSigned ? extSIRange(range, destWidth) + : extUIRange(range, destWidth); + if (srcWidth > destWidth) + return truncRange(range, destWidth); + return range; +} + +// When casting to `index`, we will take the union of the possible fixed-width +// casts. +static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range, + Type sourceType, Type destType, + bool isSigned) { + unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); + unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); + if (sourceType.isIndex()) + return makeLikeDest(range, srcWidth, destWidth, isSigned); + // We are casting to `index`, so use the union of the 32-bit and 64-bit casts. + ConstantIntRanges storageRange = + makeLikeDest(range, srcWidth, destWidth, isSigned); + ConstantIntRanges minWidthRange = + makeLikeDest(range, srcWidth, indexMinWidth, isSigned); + ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth); + ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt); + return ret; +} + +void CastSOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/true)); +} + +void CastUOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + setResultRange(getResult(), inferIndexCast(argRanges[0], sourceType, destType, + /*isSigned=*/false)); +} + +//===----------------------------------------------------------------------===// +// CmpOp +//===----------------------------------------------------------------------===// + +void CmpOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + index::IndexCmpPredicate indexPred = getPred(); + intrange::CmpPredicate pred = static_cast(indexPred); + const ConstantIntRanges &lhs = 
argRanges[0], &rhs = argRanges[1]; + + APInt min = APInt::getZero(1); + APInt max = APInt::getAllOnesValue(1); + + Optional truthValue64 = intrange::evaluatePred(pred, lhs, rhs); + + ConstantIntRanges lhsTrunc = truncRange(lhs, indexMinWidth), + rhsTrunc = truncRange(rhs, indexMinWidth); + Optional truthValue32 = + intrange::evaluatePred(pred, lhsTrunc, rhsTrunc); + + if (truthValue64 == truthValue32) { + if (truthValue64.has_value() && *truthValue64) + min = max; + else if (truthValue64.has_value() && !(*truthValue64)) + max = min; + } + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} + +//===----------------------------------------------------------------------===// +// SizeOf, which is bounded between the two supported bitwidths (32 and 64). +//===----------------------------------------------------------------------===// + +void SizeOfOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRange) { + unsigned storageWidth = + ConstantIntRanges::getStorageBitwidth(getResult().getType()); + APInt min(/*numBits=*/storageWidth, indexMinWidth); + APInt max(/*numBits=*/storageWidth, indexMaxWidth); + setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -51,3 +51,5 @@ add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) + +add_subdirectory(Utils) diff --git a/mlir/lib/Interfaces/Utils/CMakeLists.txt b/mlir/lib/Interfaces/Utils/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/Utils/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_library(MLIRInferIntRangeCommon + InferIntRangeCommon.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces/Utils + + DEPENDS + MLIRInferIntRangeInterfaceIncGen + + LINK_LIBS PUBLIC + 
MLIRInferIntRangeInterface + MLIRIR +) diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp copy from mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp copy to mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -1,29 +1,41 @@ -//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// +//===- InferIntRangeCommon.cpp - Inference for common ops ------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// This file contains implementations of range inference for operations that are +// common to both the `arith` and `index` dialects to facilitate reuse. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" + #include "llvm/Support/Debug.h" + +#include #include +using namespace mlir; + #define DEBUG_TYPE "int-range-analysis" -using namespace mlir; -using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// General utilities +//===----------------------------------------------------------------------===// /// Function that evaluates the result of doing something on arithmetic /// constants and returns std::nullopt on overflow. using ConstArithFn = function_ref(const APInt &, const APInt &)>; -/// Return the maxmially wide signed or unsigned range for a given bitwidth. 
- /// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, /// If either computation overflows, make the result unbounded. static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft, @@ -60,24 +72,113 @@ } //===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - auto constAttr = getValue().dyn_cast_or_null(); - if (constAttr) { - const APInt &value = constAttr.getValue(); - setResultRange(getResult(), ConstantIntRanges::constant(value)); +// Ext, trunc, index op handling +//===----------------------------------------------------------------------===// + +ConstantIntRanges +mlir::intrange::inferIndexOp(InferRangeFn inferFn, + ArrayRef argRanges, + intrange::CmpMode mode) { + ConstantIntRanges sixtyFour = inferFn(argRanges); + SmallVector truncated; + llvm::transform(argRanges, std::back_inserter(truncated), + [](const ConstantIntRanges &range) { + return truncRange(range, /*destWidth=*/indexMinWidth); + }); + ConstantIntRanges thirtyTwo = inferFn(truncated); + ConstantIntRanges thirtyTwoAsSixtyFour = + extRange(thirtyTwo, /*destWidth=*/indexMaxWidth); + ConstantIntRanges sixtyFourAsThirtyTwo = + truncRange(sixtyFour, /*destWidth=*/indexMinWidth); + + LLVM_DEBUG(llvm::dbgs() << "Index handling: 64-bit result = " << sixtyFour + << " 32-bit = " << thirtyTwo << "\n"); + bool truncEqual = false; + switch (mode) { + case intrange::CmpMode::Both: + truncEqual = (thirtyTwo == sixtyFourAsThirtyTwo); + break; + case intrange::CmpMode::Signed: + truncEqual = (thirtyTwo.smin() == sixtyFourAsThirtyTwo.smin() && + thirtyTwo.smax() == sixtyFourAsThirtyTwo.smax()); + break; + case intrange::CmpMode::Unsigned: + truncEqual = (thirtyTwo.umin() == sixtyFourAsThirtyTwo.umin() && + thirtyTwo.umax() == sixtyFourAsThirtyTwo.umax()); + 
break; } + if (truncEqual) + // Returning the 64-bit result preserves more information. + return sixtyFour; + ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour); + return merged; +} + +ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range, + unsigned int destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return {umin, umax, smin, smax}; +} + +ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt umin = range.umin().zext(destWidth); + APInt umax = range.umax().zext(destWidth); + return ConstantIntRanges::fromUnsigned(umin, umax); +} + +ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range, + unsigned destWidth) { + APInt smin = range.smin().sext(destWidth); + APInt smax = range.smax().sext(destWidth); + return ConstantIntRanges::fromSigned(smin, smax); +} + +ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range, + unsigned int destWidth) { + // If you truncate the upper two bytes of [0xaaaabbbb, 0xccccbbbb], + // the range of the resulting value is not contiguous and includes 0. + // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], + // but you can't truncate [255, 257] similarly. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed post-truncation rollover will not occur when either: + // - The high parts of the min and max, plus the sign bit, are the same + // - The high halves + sign bit of the min and max are either all 1s or all 0s + // and you won't create a [positive, negative] range by truncating. 
+ // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 + // but not [255, 257]_i16 to a range of i8s. You can also truncate + // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. + // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) + // will truncate to 0x7e, which is greater than 0 + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); + return {umin, umax, smin, smax}; } //===----------------------------------------------------------------------===// -// AddIOp +// Addition //===----------------------------------------------------------------------===// -void arith::AddIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferAdd(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn uadd = [](const APInt &a, const APInt &b) -> std::optional { @@ -96,15 +197,15 @@ uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + return urange.intersection(srange); } //===----------------------------------------------------------------------===// -// SubIOp +// Subtraction //===----------------------------------------------------------------------===// -void arith::SubIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { 
+ConstantIntRanges +mlir::intrange::inferSub(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn usub = [](const APInt &a, @@ -123,15 +224,15 @@ usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false); ConstantIntRanges srange = computeBoundsBy( ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + return urange.intersection(srange); } //===----------------------------------------------------------------------===// -// MulIOp +// Multiplication //===----------------------------------------------------------------------===// -void arith::MulIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferMul(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn umul = [](const APInt &a, @@ -153,12 +254,11 @@ ConstantIntRanges srange = minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, /*isSigned=*/true); - - setResultRange(getResult(), urange.intersection(srange)); + return urange.intersection(srange); } //===----------------------------------------------------------------------===// -// DivUIOp +// DivU, CeilDivU (Unsigned division) //===----------------------------------------------------------------------===// /// Fix up division results (ex. 
for ceiling and floor), returning an APInt @@ -166,9 +266,9 @@ using DivisionFixupFn = function_ref( const APInt &lhs, const APInt &rhs, const APInt &result)>; -static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { +static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); @@ -184,21 +284,38 @@ return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); } -void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivUIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); +ConstantIntRanges +mlir::intrange::inferDivU(ArrayRef argRanges) { + return inferDivURange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); +} + +ConstantIntRanges +mlir::intrange::inferCeilDivU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = + [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> std::optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + return overflowed ? 
std::optional() : corrected; + } + return result; + }; + return inferDivURange(lhs, rhs, ceilDivUIFix); } //===----------------------------------------------------------------------===// -// DivSIOp +// DivS, CeilDivS, FloorDivS (Signed division) //===----------------------------------------------------------------------===// -static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs, - DivisionFixupFn fixup) { +static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs, + DivisionFixupFn fixup) { const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), &rhsMax = rhs.smax(); bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative(); @@ -216,42 +333,15 @@ return ConstantIntRanges::maxRange(rhsMin.getBitWidth()); } -void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - setResultRange(getResult(), - inferDivSIRange(argRanges[0], argRanges[1], - [](const APInt &lhs, const APInt &rhs, - const APInt &result) { return result; })); +ConstantIntRanges +mlir::intrange::inferDivS(ArrayRef argRanges) { + return inferDivSRange(argRanges[0], argRanges[1], + [](const APInt &lhs, const APInt &rhs, + const APInt &result) { return result; }); } -//===----------------------------------------------------------------------===// -// CeilDivUIOp -//===----------------------------------------------------------------------===// - -void arith::CeilDivUIOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - DivisionFixupFn ceilDivUIFix = - [](const APInt &lhs, const APInt &rhs, - const APInt &result) -> std::optional { - if (!lhs.urem(rhs).isZero()) { - bool overflowed = false; - APInt corrected = - result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); - return overflowed ? 
std::optional() : corrected; - } - return result; - }; - setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix)); -} - -//===----------------------------------------------------------------------===// -// CeilDivSIOp -//===----------------------------------------------------------------------===// - -void arith::CeilDivSIOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferCeilDivS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; DivisionFixupFn ceilDivSIFix = @@ -265,15 +355,11 @@ } return result; }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix)); + return inferDivSRange(lhs, rhs, ceilDivSIFix); } -//===----------------------------------------------------------------------===// -// FloorDivSIOp -//===----------------------------------------------------------------------===// - -void arith::FloorDivSIOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferFloorDivS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; DivisionFixupFn floorDivSIFix = @@ -287,46 +373,15 @@ } return result; }; - setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix)); -} - -//===----------------------------------------------------------------------===// -// RemUIOp -//===----------------------------------------------------------------------===// - -void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - unsigned width = rhsMin.getBitWidth(); - APInt umin = APInt::getZero(width); - APInt umax = APInt::getMaxValue(width); - - if (!rhsMin.isZero()) { - umax = rhsMax - 1; - // Special case: sweeping out a contiguous range in N/[modulus] - if (rhsMin == rhsMax) { - const APInt 
&lhsMin = lhs.umin(), &lhsMax = lhs.umax(); - if ((lhsMax - lhsMin).ult(rhsMax)) { - APInt minRem = lhsMin.urem(rhsMax); - APInt maxRem = lhsMax.urem(rhsMax); - if (minRem.ule(maxRem)) { - umin = minRem; - umax = maxRem; - } - } - } - } - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + return inferDivSRange(lhs, rhs, floorDivSIFix); } //===----------------------------------------------------------------------===// -// RemSIOp +// Signed remainder (RemS) //===----------------------------------------------------------------------===// -void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferRemS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(), &rhsMax = rhs.smax(); @@ -357,322 +412,137 @@ } } } - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); -} - -//===----------------------------------------------------------------------===// -// AndIOp -//===----------------------------------------------------------------------===// - -/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, -/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits -/// that both bonuds have in common. This gives us a consertive approximation -/// for what values can be passed to bitwise operations. 
-static std::tuple -widenBitwiseBounds(const ConstantIntRanges &bound) { - APInt leftVal = bound.umin(), rightVal = bound.umax(); - unsigned bitwidth = leftVal.getBitWidth(); - unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); - leftVal.clearLowBits(differingBits); - rightVal.setLowBits(differingBits); - return std::make_tuple(std::move(leftVal), std::move(rightVal)); -} - -void arith::AndIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto andi = [](const APInt &a, const APInt &b) -> std::optional { - return a & b; - }; - setResultRange(getResult(), - minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + return ConstantIntRanges::fromSigned(smin, smax); } //===----------------------------------------------------------------------===// -// OrIOp +// Unsigned remainder (RemU) //===----------------------------------------------------------------------===// -void arith::OrIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto ori = [](const APInt &a, const APInt &b) -> std::optional { - return a | b; - }; - setResultRange(getResult(), - minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); -} +ConstantIntRanges +mlir::intrange::inferRemU(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; + const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); -//===----------------------------------------------------------------------===// -// XOrIOp -//===----------------------------------------------------------------------===// + unsigned width = rhsMin.getBitWidth(); + APInt umin = APInt::getZero(width); + APInt umax = APInt::getMaxValue(width); -void 
arith::XOrIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); - auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); - auto xori = [](const APInt &a, const APInt &b) -> std::optional { - return a ^ b; - }; - setResultRange(getResult(), - minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, - /*isSigned=*/false)); + if (!rhsMin.isZero()) { + umax = rhsMax - 1; + // Special case: sweeping out a contiguous range in N/[modulus] + if (rhsMin == rhsMax) { + const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(); + if ((lhsMax - lhsMin).ult(rhsMax)) { + APInt minRem = lhsMin.urem(rhsMax); + APInt maxRem = lhsMax.urem(rhsMax); + if (minRem.ule(maxRem)) { + umin = minRem; + umax = maxRem; + } + } + } + } + return ConstantIntRanges::fromUnsigned(umin, umax); } //===----------------------------------------------------------------------===// -// MaxSIOp +// Max and min (MaxS, MaxU, MinS, MinU) //===----------------------------------------------------------------------===// -void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferMaxS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().sgt(rhs.smax()) ? 
lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + return ConstantIntRanges::fromSigned(smin, smax); } -//===----------------------------------------------------------------------===// -// MaxUIOp -//===----------------------------------------------------------------------===// - -void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferMaxU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); + return ConstantIntRanges::fromUnsigned(umin, umax); } -//===----------------------------------------------------------------------===// -// MinSIOp -//===----------------------------------------------------------------------===// - -void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferMinS(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin(); const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax(); - setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax)); + return ConstantIntRanges::fromSigned(smin, smax); } -//===----------------------------------------------------------------------===// -// MinUIOp -//===----------------------------------------------------------------------===// - -void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferMinU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; const APInt &umin = lhs.umin().ult(rhs.umin()) ? 
lhs.umin() : rhs.umin(); const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax(); - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax)); -} - -//===----------------------------------------------------------------------===// -// ExtUIOp -//===----------------------------------------------------------------------===// - -static ConstantIntRanges extUIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt umin = range.umin().zext(destWidth); - APInt umax = range.umax().zext(destWidth); return ConstantIntRanges::fromUnsigned(umin, umax); } -void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extUIRange(argRanges[0], destType)); -} - -//===----------------------------------------------------------------------===// -// ExtSIOp -//===----------------------------------------------------------------------===// - -static ConstantIntRanges extSIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt smin = range.smin().sext(destWidth); - APInt smax = range.smax().sext(destWidth); - return ConstantIntRanges::fromSigned(smin, smax); -} - -void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), extSIRange(argRanges[0], destType)); -} - -//===----------------------------------------------------------------------===// -// TruncIOp -//===----------------------------------------------------------------------===// - -static ConstantIntRanges truncIRange(const ConstantIntRanges &range, - Type destType) { - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - // If you truncate the first four bytes in [0xaaaabbbb, 0xccccbbbb], - // the range of the 
resulting value is not contiguous ind includes 0. - // Ex. If you truncate [256, 258] from i16 to i8, you validly get [0, 2], - // but you can't truncate [255, 257] similarly. - bool hasUnsignedRollover = - range.umin().lshr(destWidth) != range.umax().lshr(destWidth); - APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) - : range.umin().trunc(destWidth); - APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) - : range.umax().trunc(destWidth); - - // Signed post-truncation rollover will not occur when either: - // - The high parts of the min and max, plus the sign bit, are the same - // - The high halves + sign bit of the min and max are either all 1s or all 0s - // and you won't create a [positive, negative] range by truncating. - // For example, you can truncate the ranges [256, 258]_i16 to [0, 2]_i8 - // but not [255, 257]_i16 to a range of i8s. You can also truncate - // [-256, -256]_i16 to [-2, 0]_i8, but not [-257, -255]_i16. - // You can also truncate [-130, 0]_i16 to i8 because -130_i16 (0xff7e) - // will truncate to 0x7e, which is greater than 0 - APInt sminHighPart = range.smin().ashr(destWidth - 1); - APInt smaxHighPart = range.smax().ashr(destWidth - 1); - bool hasSignedOverflow = - (sminHighPart != smaxHighPart) && - !(sminHighPart.isAllOnes() && - (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && - !(sminHighPart.isZero() && smaxHighPart.isZero()); - APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) - : range.smin().trunc(destWidth); - APInt smax = hasSignedOverflow ? 
APInt::getSignedMaxValue(destWidth) - : range.smax().trunc(destWidth); - return {umin, umax, smin, smax}; -} - -void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - Type destType = getResult().getType(); - setResultRange(getResult(), truncIRange(argRanges[0], destType)); -} - -//===----------------------------------------------------------------------===// -// IndexCastOp -//===----------------------------------------------------------------------===// - -void arith::IndexCastOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - - if (srcWidth < destWidth) - setResultRange(getResult(), extSIRange(argRanges[0], destType)); - else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); - else - setResultRange(getResult(), argRanges[0]); -} - //===----------------------------------------------------------------------===// -// IndexCastUIOp +// Bitwise operators (And, Or, Xor) //===----------------------------------------------------------------------===// -void arith::IndexCastUIOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRange) { - Type sourceType = getOperand().getType(); - Type destType = getResult().getType(); - unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); - unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - - if (srcWidth < destWidth) - setResultRange(getResult(), extUIRange(argRanges[0], destType)); - else if (srcWidth > destWidth) - setResultRange(getResult(), truncIRange(argRanges[0], destType)); - else - setResultRange(getResult(), argRanges[0]); +/// "Widen" bounds - if 0bvvvvv??? 
<= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bounds have in common. This gives us a conservative approximation +/// for what values can be passed to bitwise operations. +static std::tuple +widenBitwiseBounds(const ConstantIntRanges &bound) { + APInt leftVal = bound.umin(), rightVal = bound.umax(); + unsigned bitwidth = leftVal.getBitWidth(); + unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + return std::make_tuple(std::move(leftVal), std::move(rightVal)); } -//===----------------------------------------------------------------------===// -// CmpIOp -//===----------------------------------------------------------------------===// - -bool isStaticallyTrue(arith::CmpIPredicate pred, const ConstantIntRanges &lhs, - const ConstantIntRanges &rhs) { - switch (pred) { - case arith::CmpIPredicate::sle: - case arith::CmpIPredicate::slt: - return (applyCmpPredicate(pred, lhs.smax(), rhs.smin())); - case arith::CmpIPredicate::ule: - case arith::CmpIPredicate::ult: - return applyCmpPredicate(pred, lhs.umax(), rhs.umin()); - case arith::CmpIPredicate::sge: - case arith::CmpIPredicate::sgt: - return applyCmpPredicate(pred, lhs.smin(), rhs.smax()); - case arith::CmpIPredicate::uge: - case arith::CmpIPredicate::ugt: - return applyCmpPredicate(pred, lhs.umin(), rhs.umax()); - case arith::CmpIPredicate::eq: { - std::optional lhsConst = lhs.getConstantValue(); - std::optional rhsConst = rhs.getConstantValue(); - return lhsConst && rhsConst && lhsConst == rhsConst; - } - case arith::CmpIPredicate::ne: { - // While equality requires that there is an interpration of the preceeding - // computations that produces equal constants, whether that be signed or - // unsigned, statically determining inequality requires that neither - // interpretation produce potentially overlapping ranges. 
- bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); - bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || - isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); - return sne && une; - } - } - return false; +ConstantIntRanges +mlir::intrange::inferAnd(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto andi = [](const APInt &a, const APInt &b) -> std::optional { + return a & b; + }; + return minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); } -void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - arith::CmpIPredicate pred = getPredicate(); - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; - - APInt min = APInt::getZero(1); - APInt max = APInt::getAllOnesValue(1); - if (isStaticallyTrue(pred, lhs, rhs)) - min = max; - else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) - max = min; - - setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); +ConstantIntRanges +mlir::intrange::inferOr(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto ori = [](const APInt &a, const APInt &b) -> std::optional { + return a | b; + }; + return minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); } -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -void arith::SelectOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - std::optional mbCondVal = argRanges[0].getConstantValue(); - - if (mbCondVal) { - if (mbCondVal->isZero()) - setResultRange(getResult(), argRanges[2]); - else - setResultRange(getResult(), argRanges[1]); - return; - } - 
setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2])); +ConstantIntRanges +mlir::intrange::inferXor(ArrayRef argRanges) { + auto [lhsZeros, lhsOnes] = widenBitwiseBounds(argRanges[0]); + auto [rhsZeros, rhsOnes] = widenBitwiseBounds(argRanges[1]); + auto xori = [](const APInt &a, const APInt &b) -> std::optional { + return a ^ b; + }; + return minMaxBy(xori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, + /*isSigned=*/false); } //===----------------------------------------------------------------------===// -// ShLIOp +// Shifts (Shl, ShrS, ShrU) //===----------------------------------------------------------------------===// -void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { +ConstantIntRanges +mlir::intrange::inferShl(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn shl = [](const APInt &l, const APInt &r) -> std::optional { @@ -684,40 +554,110 @@ ConstantIntRanges srange = minMaxBy(shl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, /*isSigned=*/true); - setResultRange(getResult(), urange.intersection(srange)); + return urange.intersection(srange); } -//===----------------------------------------------------------------------===// -// ShRUIOp -//===----------------------------------------------------------------------===// +ConstantIntRanges +mlir::intrange::inferShrS(ArrayRef argRanges) { + const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; -void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { + ConstArithFn ashr = [](const APInt &l, + const APInt &r) -> std::optional { + return r.uge(r.getBitWidth()) ? 
std::optional() : l.ashr(r); + }; + + return minMaxBy(ashr, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/true); +} + +ConstantIntRanges +mlir::intrange::inferShrU(ArrayRef argRanges) { const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; ConstArithFn lshr = [](const APInt &l, const APInt &r) -> std::optional { return r.uge(r.getBitWidth()) ? std::optional() : l.lshr(r); }; - setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()}, - {rhs.umin(), rhs.umax()}, - /*isSigned=*/false)); + return minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*isSigned=*/false); } //===----------------------------------------------------------------------===// -// ShRSIOp +// Comparisons (Cmp) //===----------------------------------------------------------------------===// -void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRange) { - const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; +static intrange::CmpPredicate invertPredicate(intrange::CmpPredicate pred) { + switch (pred) { + case intrange::CmpPredicate::eq: + return intrange::CmpPredicate::ne; + case intrange::CmpPredicate::ne: + return intrange::CmpPredicate::eq; + case intrange::CmpPredicate::slt: + return intrange::CmpPredicate::sge; + case intrange::CmpPredicate::sle: + return intrange::CmpPredicate::sgt; + case intrange::CmpPredicate::sgt: + return intrange::CmpPredicate::sle; + case intrange::CmpPredicate::sge: + return intrange::CmpPredicate::slt; + case intrange::CmpPredicate::ult: + return intrange::CmpPredicate::uge; + case intrange::CmpPredicate::ule: + return intrange::CmpPredicate::ugt; + case intrange::CmpPredicate::ugt: + return intrange::CmpPredicate::ule; + case intrange::CmpPredicate::uge: + return intrange::CmpPredicate::ult; + } + llvm_unreachable("unknown cmp predicate value"); +} - ConstArithFn ashr = [](const APInt &l, - const APInt &r) -> std::optional { - return r.uge(r.getBitWidth()) ? 
std::optional() : l.ashr(r); - }; +static bool isStaticallyTrue(intrange::CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + switch (pred) { + case intrange::CmpPredicate::sle: + return lhs.smax().sle(rhs.smin()); + case intrange::CmpPredicate::slt: + return lhs.smax().slt(rhs.smin()); + case intrange::CmpPredicate::ule: + return lhs.umax().ule(rhs.umin()); + case intrange::CmpPredicate::ult: + return lhs.umax().ult(rhs.umin()); + case intrange::CmpPredicate::sge: + return lhs.smin().sge(rhs.smax()); + case intrange::CmpPredicate::sgt: + return lhs.smin().sgt(rhs.smax()); + case intrange::CmpPredicate::uge: + return lhs.umin().uge(rhs.umax()); + case intrange::CmpPredicate::ugt: + return lhs.umin().ugt(rhs.umax()); + case intrange::CmpPredicate::eq: { + std::optional lhsConst = lhs.getConstantValue(); + std::optional rhsConst = rhs.getConstantValue(); + return lhsConst && rhsConst && lhsConst == rhsConst; + } + case intrange::CmpPredicate::ne: { + // While equality requires that there is an interpretation of the preceding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. 
+ bool sne = isStaticallyTrue(intrange::CmpPredicate::slt, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(intrange::CmpPredicate::ult, lhs, rhs) || + isStaticallyTrue(intrange::CmpPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} - setResultRange(getResult(), - minMaxBy(ashr, {lhs.smin(), lhs.smax()}, - {rhs.umin(), rhs.umax()}, /*isSigned=*/true)); +std::optional mlir::intrange::evaluatePred(CmpPredicate pred, + const ConstantIntRanges &lhs, + const ConstantIntRanges &rhs) { + if (isStaticallyTrue(pred, lhs, rhs)) + return true; + if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + return false; + return std::nullopt; } diff --git a/mlir/test/Dialect/Index/int-range-inference.mlir b/mlir/test/Dialect/Index/int-range-inference.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Index/int-range-inference.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s + +// Most operations are covered by the `arith` tests, which use the same code +// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling +// code is operating as expected. 
+ +// CHECK-LABEL: func @add_same_for_both +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_same_for_both(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %calmostBig = index.constant 0xfffffffe + %0 = index.minu %arg0, %calmostBig + %1 = index.add %0, %c1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_unsigned_ov +// CHECK: %[[uge:.*]] = index.cmp uge +// CHECK: return %[[uge]] +func.func @add_unsigned_ov(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cu32_max = index.constant 0xffffffff + %0 = index.minu %arg0, %cu32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be >= 1 + %2 = index.cmp uge(%1, %c1) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_signed_ov +// CHECK: %[[sge:.*]] = index.cmp sge +// CHECK: return %[[sge]] +func.func @add_signed_ov(%arg0 : index) -> i1 { + %c0 = index.constant 0 + %c1 = index.constant 1 + %ci32_max = index.constant 0x7fffffff + %0 = index.minu %arg0, %ci32_max + %1 = index.add %0, %c1 + // On 32-bit, the add could wrap, so the result doesn't have to be positive + %2 = index.cmp sge(%1, %c0) + func.return %2 : i1 +} + +// CHECK-LABEL: func @add_big +// CHECK: %[[true:.*]] = index.bool.constant true +// CHECK: return %[[true]] +func.func @add_big(%arg0 : index) -> i1 { + %c1 = index.constant 1 + %cmin = index.constant 0x300000000 + %cmax = index.constant 0x30000ffff + // Note: the order of the clamps matters. + // If you go max, then min, you infer the ranges [0x300...0, 0xff..ff] + // and then [0x30...0000, 0x30...ffff] + // If you switch the order of the below operations, you instead first infer + // the range [0,0x3...ffff]. Then, the min inference can't constrain + // this intermediate, since in the 32-bit case we could have, for example + // trunc(%arg0 = 0x2ffffffff) = 0xffffffff > trunc(0x30000ffff) = 0x0000ffff + // which means we can't do any inference. 
+ %0 = index.maxu %arg0, %cmin + %1 = index.minu %0, %cmax + %2 = index.add %1, %c1 + %3 = index.cmp uge(%1, %cmin) + func.return %3 : i1 +}