diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -12,6 +12,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -49,7 +50,8 @@ // Base class for integer binary operations. class Arith_IntBinaryOp traits = []> : - Arith_BinaryOp, + Arith_BinaryOp]>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, Results<(outs SignlessIntegerLike:$result)>; @@ -70,7 +72,7 @@ class Arith_CastOp traits = []> : Arith_Op]>, + DeclareOpInterfaceMethods]>, Arguments<(ins From:$in)>, Results<(outs To:$out)> { let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)"; @@ -87,7 +89,9 @@ // Cast from an integer type to another integer type. class Arith_IToICastOp traits = []> : Arith_CastOp; + SignlessFixedWidthIntegerLike, + traits # + [DeclareOpInterfaceMethods]>; // Cast from an integer type to a floating point type.
class Arith_IToFCastOp traits = []> : Arith_CastOp; @@ -124,7 +128,8 @@ def Arith_ConstantOp : Op, - AllTypesMatch<["value", "result"]>]> { + AllTypesMatch<["value", "result"]>, + DeclareOpInterfaceMethods]> { let summary = "integer or floating point constant"; let description = [{ The `constant` operation produces an SSA value equal to some integer or @@ -971,8 +976,9 @@ MemRefOf<[AnySignlessInteger, Index]>.predicate]>, "signless-integer-like or memref of signless-integer">; -def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, - IndexCastTypeConstraint> { +def Arith_IndexCastOp + : Arith_CastOp<"index_cast", IndexCastTypeConstraint, IndexCastTypeConstraint, + [DeclareOpInterfaceMethods]> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -1024,7 +1030,9 @@ // CmpIOp //===----------------------------------------------------------------------===// -def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> { +def Arith_CmpIOp + : Arith_CompareOpOfAnyRank<"cmpi", + [DeclareOpInterfaceMethods]> { let summary = "integer comparison operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. Its two @@ -1165,7 +1173,8 @@ //===----------------------------------------------------------------------===// def SelectOp : Arith_Op<"select", [ - AllTypesMatch<["true_value", "false_value", "result"]> + AllTypesMatch<["true_value", "false_value", "result"]>, + DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ @@ -1205,7 +1214,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - + // FIXME: Switch this to use the declarative assembly format.
let hasCustomAssemblyFormat = 1; } diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h --- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h +++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h @@ -56,25 +56,40 @@ /// non-integer types this is 0. static unsigned getStorageBitwidth(Type type); - /// Create an `IntRangeAttrs` where `min` is both the signed and unsigned - /// minimum and `max` is both the signed and unsigned maximum. - static ConstantIntRanges range(const APInt &min, const APInt &max); - - /// Create an `IntRangeAttrs` with the signed minimum and maximum equal + /// Create a `ConstantIntRanges` with the maximum bounds for the width + /// `bitwidth`, that is - [0, uint_max(width)]/[sint_min(width), + /// sint_max(width)]. + static ConstantIntRanges maxRange(unsigned bitwidth); + + /// Create a `ConstantIntRanges` with a constant value - that is, with the + /// bounds [value, value] for both its signed and unsigned interpretations. + static ConstantIntRanges constant(const APInt &value); + + /// Create a `ConstantIntRanges` whose minimum is `min` and maximum is `max` + /// with `isSigned` specifying if the min and max should be interpreted as + /// signed or unsigned. + static ConstantIntRanges range(const APInt &min, const APInt &max, + bool isSigned); + + /// Create a `ConstantIntRanges` with the signed minimum and maximum equal /// to `smin` and `smax`, where the unsigned bounds are constructed from the /// signed ones if they correspond to a contigious range of bit patterns when /// viewed as unsigned values and are left at [0, int_max()] otherwise.
static ConstantIntRanges fromSigned(const APInt &smin, const APInt &smax); - /// Create an `IntRangeAttrs` with the unsigned minimum and maximum equal + /// Create a `ConstantIntRanges` with the unsigned minimum and maximum equal /// to `umin` and `umax` and the signed part equal to `umin` and `umax` /// unless the sign bit changes between the minimum and maximum. static ConstantIntRanges fromUnsigned(const APInt &umin, const APInt &umax); /// Returns the union (computed separately for signed and unsigned bounds) - /// of `a` and `b`. + /// of this range and `other`. ConstantIntRanges rangeUnion(const ConstantIntRanges &other) const; + /// Returns the intersection (computed separately for signed and unsigned + /// bounds) of this range and `other`. + ConstantIntRanges intersection(const ConstantIntRanges &other) const; + /// If either the signed or unsigned interpretations of the range /// indicate that the value it bounds is a constant, return that constant /// value. diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp --- a/mlir/lib/Analysis/IntRangeAnalysis.cpp +++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp @@ -43,7 +43,7 @@ /// value being marked overdefined is even an integer.
static IntRangeLattice getPessimisticValueState(MLIRContext *context) { APInt noIntValue = APInt::getZeroWidth(); - return ConstantIntRanges::range(noIntValue, noIntValue); + return ConstantIntRanges(noIntValue, noIntValue, noIntValue, noIntValue); } /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)]) diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRArithmeticDialect ArithmeticOps.cpp ArithmeticDialect.cpp + InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic @@ -14,6 +15,7 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR ) diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,660 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; +using namespace mlir::arith; + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns None on overflow. +using ConstArithFn = + function_ref(const APInt &, const APInt &)>; + +// ConstArithFn does not own its callable; keep lambdas in `auto` locals.
+
+/// Compute `op(minLeft, minRight)` and `op(maxLeft, maxRight)` if possible.
+/// If either computation overflows, return the maximal range for the
+/// bitwidth instead. The result is only sound when the caller picks the
+/// argument pairs that produce the extreme results (e.g. SubIOp pairs the
+/// minimum of `lhs` with the maximum of `rhs`).
+///
+/// `ConstArithFn` is a non-owning `function_ref`: callers keep their lambdas
+/// in `auto` locals and let them convert at the call site, so the callable
+/// outlives every use of the reference.
+static ConstantIntRanges computeBoundsBy(ConstArithFn op, const APInt &minLeft,
+                                         const APInt &minRight,
+                                         const APInt &maxLeft,
+                                         const APInt &maxRight, bool isSigned) {
+  Optional<APInt> maybeMin = op(minLeft, minRight);
+  Optional<APInt> maybeMax = op(maxLeft, maxRight);
+  if (maybeMin.hasValue() && maybeMax.hasValue())
+    return ConstantIntRanges::range(*maybeMin, *maybeMax, isSigned);
+  return ConstantIntRanges::maxRange(minLeft.getBitWidth());
+}
+
+/// Compute the minimum and maximum of `op(l, r)` over every `l` in `lhs` and
+/// every `r` in `rhs`, ordering values as signed or unsigned per `isSigned`.
+/// Returns the maximal range if `op` overflows (returns None) for any pair.
+/// `lhs` must be non-empty; callers pass the corner values of each range.
+static ConstantIntRanges minMaxBy(ConstArithFn op, ArrayRef<APInt> lhs,
+                                  ArrayRef<APInt> rhs, bool isSigned) {
+  unsigned width = lhs[0].getBitWidth();
+  // Start from the opposite extremes so any computed result tightens them.
+  APInt min =
+      isSigned ? APInt::getSignedMaxValue(width) : APInt::getMaxValue(width);
+  APInt max =
+      isSigned ? APInt::getSignedMinValue(width) : APInt::getZero(width);
+  for (const APInt &left : lhs) {
+    for (const APInt &right : rhs) {
+      Optional<APInt> maybeThisResult = op(left, right);
+      if (!maybeThisResult)
+        return ConstantIntRanges::maxRange(width);
+      APInt result = std::move(*maybeThisResult);
+      min = (isSigned ? result.slt(min) : result.ult(min)) ? result : min;
+      max = (isSigned ? result.sgt(max) : result.ugt(max)) ? result : max;
+    }
+  }
+  return ConstantIntRanges::range(min, max, isSigned);
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp
+//===----------------------------------------------------------------------===//
+
+void arith::ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                          SetIntRangeFn setResultRange) {
+  // Only integer attributes yield a range; for any other attribute kind the
+  // result range is left unset.
+  auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
+  if (constAttr) {
+    const APInt &value = constAttr.getValue();
+    setResultRange(getResult(), ConstantIntRanges::constant(value));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// AddIOp
+//===----------------------------------------------------------------------===//
+
+void arith::AddIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // Named `auto` lambdas rather than `ConstArithFn` locals: a `function_ref`
+  // initialized from a lambda temporary dangles once its statement ends.
+  auto uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.uadd_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  auto sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.sadd_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+
+  // Addition is monotonic in both operands: the bounds are the sums of the
+  // minima and of the maxima. Both interpretations are then intersected.
+  ConstantIntRanges urange = computeBoundsBy(
+      uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// SubIOp
+//===----------------------------------------------------------------------===//
+
+void arith::SubIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  auto usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.usub_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  auto ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.ssub_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  // `lhs - rhs` is smallest for (min lhs, max rhs) and largest for
+  // (max lhs, min rhs).
+  ConstantIntRanges urange = computeBoundsBy(
+      usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
+  ConstantIntRanges srange = computeBoundsBy(
+      ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// MulIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MulIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  auto umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.umul_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  auto smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.smul_ov(b, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+
+  // The extremes of a product lie at the corners of the operand ranges.
+  ConstantIntRanges urange =
+      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+               /*isSigned=*/true);
+
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// DivUIOp
+//===----------------------------------------------------------------------===//
+
+/// Adjusts the truncating quotient `result` of `lhs / rhs` (e.g. to round it
+/// towards positive or negative infinity). Returns None if the adjustment
+/// overflows.
+using DivisionFixupFn = function_ref<Optional<APInt>(
+    const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+/// Range of an unsigned division whose quotient is post-processed by `fixup`.
+/// Gives up (maximal range) when the divisor range contains zero.
+static ConstantIntRanges inferDivUIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
+              &rhsMax = rhs.umax();
+
+  if (!rhsMin.isZero()) {
+    auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      return fixup(a, b, a.udiv(b));
+    };
+    return minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/false);
+  }
+  // Otherwise, it's possible we might divide by 0.
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+void arith::DivUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferDivUIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
+}
+
+//===----------------------------------------------------------------------===//
+// DivSIOp
+//===----------------------------------------------------------------------===//
+
+/// Range of a signed division whose quotient is post-processed by `fixup`.
+/// Gives up (maximal range) when the divisor range contains zero or when a
+/// corner division (e.g. INT_MIN / -1) or its fixup overflows.
+static ConstantIntRanges inferDivSIRange(const ConstantIntRanges &lhs,
+                                         const ConstantIntRanges &rhs,
+                                         DivisionFixupFn fixup) {
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+  bool canDivide = rhsMin.isStrictlyPositive() || rhsMax.isNegative();
+
+  if (canDivide) {
+    auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      bool overflowed = false;
+      APInt result = a.sdiv_ov(b, overflowed);
+      return overflowed ? Optional<APInt>() : fixup(a, b, result);
+    };
+    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                    /*isSigned=*/true);
+  }
+  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+}
+
+void arith::DivSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferDivSIRange(argRanges[0], argRanges[1],
+                                 [](const APInt &lhs, const APInt &rhs,
+                                    const APInt &result) { return result; }));
+}
+
+//===----------------------------------------------------------------------===//
+// CeilDivUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::CeilDivUIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Round the truncated quotient up whenever the division is inexact.
+  auto ceilDivUIFix = [](const APInt &a, const APInt &b,
+                         const APInt &result) -> Optional<APInt> {
+    if (!a.urem(b).isZero()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivUIRange(lhs, rhs, ceilDivUIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// CeilDivSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::CeilDivSIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Signed division truncates towards zero, so an inexact quotient only needs
+  // rounding up when the exact quotient is positive (operands of equal sign).
+  auto ceilDivSIFix = [](const APInt &a, const APInt &b,
+                         const APInt &result) -> Optional<APInt> {
+    if (!a.srem(b).isZero() && a.isNonNegative() == b.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, ceilDivSIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// FloorDivSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::FloorDivSIOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // An inexact quotient with a negative exact value (operands of opposite
+  // sign) was truncated upwards, so round it down by one.
+  auto floorDivSIFix = [](const APInt &a, const APInt &b,
+                          const APInt &result) -> Optional<APInt> {
+    if (!a.srem(b).isZero() && a.isNonNegative() != b.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed);
+      return overflowed ? Optional<APInt>() : corrected;
+    }
+    return result;
+  };
+  setResultRange(getResult(), inferDivSIRange(lhs, rhs, floorDivSIFix));
+}
+
+//===----------------------------------------------------------------------===//
+// RemUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::RemUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+
+  unsigned width = rhsMin.getBitWidth();
+  APInt umin = APInt::getZero(width);
+  APInt umax = APInt::getMaxValue(width);
+
+  // With a divisor that cannot be zero, the remainder is below the divisor.
+  if (!rhsMin.isZero()) {
+    umax = rhsMax - 1;
+    // Special case: sweeping out a contiguous range in N/[modulus]. A dividend
+    // range narrower than a constant modulus maps to [minRem, maxRem] unless
+    // it wraps past a multiple of the modulus (then minRem > maxRem).
+    if (rhsMin == rhsMax) {
+      const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax();
+      if ((lhsMax - lhsMin).ult(rhsMax)) {
+        APInt minRem = lhsMin.urem(rhsMax);
+        APInt maxRem = lhsMax.urem(rhsMax);
+        if (minRem.ule(maxRem)) {
+          umin = minRem;
+          umax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// RemSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::RemSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
+              &rhsMax = rhs.smax();
+
+  unsigned width = rhsMax.getBitWidth();
+  APInt smin = APInt::getSignedMinValue(width);
+  APInt smax = APInt::getSignedMaxValue(width);
+  // No bounds if zero could be a divisor.
+  bool canBound = (rhsMin.isStrictlyPositive() || rhsMax.isNegative());
+  if (canBound) {
+    // The remainder's magnitude is below the largest divisor magnitude, and
+    // its sign follows the dividend's.
+    APInt maxDivisor = rhsMin.isStrictlyPositive() ? rhsMax : rhsMin.abs();
+    bool canNegativeDividend = lhsMin.isNegative();
+    bool canPositiveDividend = lhsMax.isStrictlyPositive();
+    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
+    APInt maxPositiveResult = maxDivisor - 1;
+    APInt minNegativeResult = -maxPositiveResult;
+    smin = canNegativeDividend ? minNegativeResult : zero;
+    smax = canPositiveDividend ? maxPositiveResult : zero;
+    // Special case: sweeping out a contiguous range in N/[modulus].
+    if (rhsMin == rhsMax) {
+      if ((lhsMax - lhsMin).ult(maxDivisor)) {
+        APInt minRem = lhsMin.srem(maxDivisor);
+        APInt maxRem = lhsMax.srem(maxDivisor);
+        if (minRem.sle(maxRem)) {
+          smin = minRem;
+          smax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// AndIOp
+//===----------------------------------------------------------------------===//
+
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations.
+static std::tuple<APInt, APInt>
+widenBitwiseBounds(const ConstantIntRanges &bound) {
+  APInt leftVal = bound.umin(), rightVal = bound.umax();
+  unsigned bitwidth = leftVal.getBitWidth();
+  // Every bit below the highest bit where the bounds differ is "unknown".
+  unsigned differingBits = bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+  leftVal.clearLowBits(differingBits);
+  rightVal.setLowBits(differingBits);
+  return {leftVal, rightVal};
+}
+
+void arith::AndIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  // Each value in a widened range sets a superset of the lower bound's bits
+  // and a subset of the upper bound's, so `and` of the corners bounds `and`
+  // of any pair of values.
+  auto andi = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    return a & b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(andi, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// OrIOp
+//===----------------------------------------------------------------------===//
+
+void arith::OrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                     SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  // Same bit-subset argument as for AndIOp: `or` is monotonic in the set bits.
+  auto ori = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    return a | b;
+  };
+  setResultRange(getResult(),
+                 minMaxBy(ori, {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes},
+                          /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// XOrIOp
+//===----------------------------------------------------------------------===//
+
+void arith::XOrIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  APInt lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(argRanges[1]);
+  // `xor` is not monotonic, so evaluating it only at the corners of the
+  // widened ranges is unsound ([0, 3] ^ 1 yields every value in [0, 3], but
+  // the corners only give 1 and 2). Instead: bits above the widest unknown
+  // suffix of either operand are fixed in both operands and hence in the
+  // result, while every bit inside that suffix may take either value.
+  APInt unknownBits = (lhsZeros ^ lhsOnes) | (rhsZeros ^ rhsOnes);
+  APInt umin = (lhsZeros ^ rhsZeros) & ~unknownBits;
+  APInt umax = umin | unknownBits;
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// MaxSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MaxSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// MaxUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MaxUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ugt(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ugt(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// MinSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MinSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &smin = lhs.smin().slt(rhs.smin()) ? lhs.smin() : rhs.smin();
+  const APInt &smax = lhs.smax().slt(rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), ConstantIntRanges::fromSigned(smin, smax));
+}
+
+//===----------------------------------------------------------------------===//
+// MinUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::MinUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  const APInt &umin = lhs.umin().ult(rhs.umin()) ? lhs.umin() : rhs.umin();
+  const APInt &umax = lhs.umax().ult(rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// ExtUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ExtUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  Type destType = getResult().getType();
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt umin = argRanges[0].umin().zext(destWidth);
+  APInt umax = argRanges[0].umax().zext(destWidth);
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(umin, umax));
+}
+
+//===----------------------------------------------------------------------===//
+// ExtSIOp
+//===----------------------------------------------------------------------===//
+
+/// Sign-extend the signed bounds of `range` to the bitwidth of `destType`;
+/// the unsigned bounds are rederived by `fromSigned`.
+static ConstantIntRanges extSIRange(const ConstantIntRanges &range,
+                                    Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+  APInt smin = range.smin().sext(destWidth);
+  APInt smax = range.smax().sext(destWidth);
+  return ConstantIntRanges::fromSigned(smin, smax);
+}
+
+void arith::ExtSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  Type destType = getResult().getType();
+  setResultRange(getResult(), extSIRange(argRanges[0], destType));
+}
+
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+/// Truncate `range` to the bitwidth of `destType`. Truncating each bound on
+/// its own is only valid when no value in the range wraps around; otherwise
+/// the corresponding (unsigned or signed) part becomes the full range.
+static ConstantIntRanges truncIRange(const ConstantIntRanges &range,
+                                     Type destType) {
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+  // Unsigned order survives truncation iff all values share the bits at and
+  // above `destWidth` (`lshr` is monotonic, so checking the bounds suffices).
+  // E.g. truncating [255, 256] from i16 to i8 yields {255, 0}, not [255, 0].
+  bool unsignedWraps =
+      range.umin().lshr(destWidth) != range.umax().lshr(destWidth);
+  APInt umin = unsignedWraps ? APInt::getZero(destWidth)
+                             : range.umin().trunc(destWidth);
+  APInt umax = unsignedWraps ? APInt::getMaxValue(destWidth)
+                             : range.umax().trunc(destWidth);
+
+  // Signed order survives if both bounds already fit in `destWidth` bits, or
+  // if all values share the bits from the destination sign bit upwards (e.g.
+  // [-256, -250] in i16 truncates to [0, 6] in i8). Otherwise, e.g. for
+  // [127, 128] truncated from i16 to i8, the values wrap to {127, -128}.
+  bool signedFits = range.smin().isSignedIntN(destWidth) &&
+                    range.smax().isSignedIntN(destWidth);
+  bool sameHighPart =
+      range.smin().ashr(destWidth - 1) == range.smax().ashr(destWidth - 1);
+  bool signedWraps = !signedFits && !sameHighPart;
+  APInt smin = signedWraps ? APInt::getSignedMinValue(destWidth)
+                           : range.smin().trunc(destWidth);
+  APInt smax = signedWraps ? APInt::getSignedMaxValue(destWidth)
+                           : range.smax().trunc(destWidth);
+  return {umin, umax, smin, smax};
+}
+
+void arith::TruncIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRange) {
+  Type destType = getResult().getType();
+  setResultRange(getResult(), truncIRange(argRanges[0], destType));
+}
+
+//===----------------------------------------------------------------------===//
+// IndexCastOp
+//===----------------------------------------------------------------------===//
+
+void arith::IndexCastOp::inferResultRanges(
+    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
+  unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
+
+  // Widening index casts sign-extend; narrowing ones truncate.
+  if (srcWidth < destWidth)
+    setResultRange(getResult(), extSIRange(argRanges[0], destType));
+  else if (srcWidth > destWidth)
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
+  else
+    setResultRange(getResult(), argRanges[0]);
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if `pred` is known to hold for every pair of values described
+/// by `lhs` and `rhs`; false means "unknown", not "known false".
+static bool isStaticallyTrue(arith::CmpIPredicate pred,
+                             const ConstantIntRanges &lhs,
+                             const ConstantIntRanges &rhs) {
+  switch (pred) {
+  case arith::CmpIPredicate::sle:
+  case arith::CmpIPredicate::slt:
+    return (applyCmpPredicate(pred, lhs.smax(), rhs.smin()));
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult:
+    return applyCmpPredicate(pred, lhs.umax(), rhs.umin());
+  case arith::CmpIPredicate::sge:
+  case arith::CmpIPredicate::sgt:
+    return applyCmpPredicate(pred, lhs.smin(), rhs.smax());
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+    return applyCmpPredicate(pred, lhs.umin(), rhs.umax());
+  case arith::CmpIPredicate::eq: {
+    Optional<APInt> lhsConst = lhs.getConstantValue();
+    Optional<APInt> rhsConst = rhs.getConstantValue();
+    return lhsConst && rhsConst && lhsConst == rhsConst;
+  }
+  case arith::CmpIPredicate::ne: {
+    // While equality requires that there is an interpretation of the preceding
+    // computations that produces equal constants, whether that be signed or
+    // unsigned, statically determining inequality requires that neither
+    // interpretation produce potentially overlapping ranges.
+    bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
+    bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
+    return sne && une;
+  }
+  }
+  return false;
+}
+
+void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  arith::CmpIPredicate pred = getPredicate();
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Start from "either true or false" and collapse to a constant when the
+  // predicate (or its inverse) is known to hold.
+  APInt min = APInt::getZero(1);
+  APInt max = APInt::getAllOnesValue(1);
+  if (isStaticallyTrue(pred, lhs, rhs))
+    min = max;
+  else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+    max = min;
+
+  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max));
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+void arith::SelectOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                        SetIntRangeFn setResultRange) {
+  Optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+
+  // A known condition selects exactly one operand's range.
+  if (mbCondVal) {
+    if (mbCondVal->isZero())
+      setResultRange(getResult(), argRanges[2]);
+    else
+      setResultRange(getResult(), argRanges[1]);
+    return;
+  }
+  setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2]));
+}
+
+//===----------------------------------------------------------------------===//
+// ShLIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShLIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                      SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  // A left shift by `r` multiplies by 2^r, so (as for MulIOp) the corner
+  // values only bound the result when none of them overflow. The `*_ov`
+  // variants also report shift amounts >= the bitwidth as overflow.
+  auto ushl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = l.ushl_ov(r, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  auto sshl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = l.sshl_ov(r, overflowed);
+    return overflowed ? Optional<APInt>() : result;
+  };
+  ConstantIntRanges urange =
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/false);
+  ConstantIntRanges srange =
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+               /*isSigned=*/true);
+  setResultRange(getResult(), urange.intersection(srange));
+}
+
+//===----------------------------------------------------------------------===//
+// ShRUIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShRUIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Shift amounts >= the bitwidth are treated like overflow (no bound).
+  auto lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.lshr(r);
+  };
+  setResultRange(getResult(), minMaxBy(lshr, {lhs.umin(), lhs.umax()},
+                                       {rhs.umin(), rhs.umax()},
+                                       /*isSigned=*/false));
+}
+
+//===----------------------------------------------------------------------===//
+// ShRSIOp
+//===----------------------------------------------------------------------===//
+
+void arith::ShRSIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                       SetIntRangeFn setResultRange) {
+  const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+
+  // Shift amounts >= the bitwidth are treated like overflow (no bound).
+  auto ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    return r.uge(r.getBitWidth()) ? Optional<APInt>() : l.ashr(r);
+  };
+
+  setResultRange(getResult(),
+                 minMaxBy(ashr, {lhs.smin(), lhs.smax()},
+                          {rhs.umin(), rhs.umax()}, /*isSigned=*/true));
+}
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -35,8 +35,19 @@
   return 0;
 }
 
-ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max) {
-  return {min, max, min, max};
+ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
+  return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
+}
+
+ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
+  return {value, value, value, value};
+}
+
+ConstantIntRanges ConstantIntRanges::range(const APInt &min, const APInt &max,
+                                           bool isSigned) {
+  if (isSigned)
+    return fromSigned(min, max);
+  return fromUnsigned(min, max);
 }
 
 ConstantIntRanges ConstantIntRanges::fromSigned(const APInt &smin,
@@ -84,6 +95,23 @@
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
 
+ConstantIntRanges
+ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
+  // "Not an integer" poisons everything and also cannot be fed to comparison
+  // operators.
+  if (umin().getBitWidth() == 0)
+    return *this;
+  if (other.umin().getBitWidth() == 0)
+    return other;
+
+  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+
+  return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
+}
+
 Optional<APInt> ConstantIntRanges::getConstantValue() const {
   // Note: we need to exclude the trivially-equal width 0 values here.
   if (umin() == umax() && umin().getBitWidth() != 0)
diff --git a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir
@@ -0,0 +1,647 @@
+// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @add_min_max
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: return %[[c3]]
+func.func @add_min_max(%a: index, %b: index) -> index {
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %0 = arith.minsi %a, %c1 : index
+    %1 = arith.maxsi %0, %c1 : index
+    %2 = arith.minui %b, %c2 : index
+    %3 = arith.maxui %2, %c2 : index
+    %4 = arith.addi %1, %3 : index
+    func.return %4 : index
+}
+
+// CHECK-LABEL: func @add_lower_bound
+// CHECK: %[[sge:.*]] = arith.cmpi sge
+// CHECK: return %[[sge]]
+func.func @add_lower_bound(%a : i32, %b : i32) -> i1 {
+    %c1 = arith.constant 1 : i32
+    %c2 = arith.constant 2 : i32
+    %0 = arith.maxsi %a, %c1 : i32
+    %1 = arith.maxsi %b, %c1 : i32
+    %2 = arith.addi %0, %1 : i32
+    %3 = arith.cmpi sge, %2, %c2 : i32
+    %4 = arith.cmpi uge, %2, %c2 : i32
+    %5 = arith.andi %3, %4 : i1
+    func.return %5 : i1
+}
+
+// CHECK-LABEL: func @sub_signed_vs_unsigned
+// CHECK-NOT: arith.cmpi sle
+// CHECK: %[[unsigned:.*]] = 
arith.cmpi ule +// CHECK: return %[[unsigned]] : i1 +func.func @sub_signed_vs_unsigned(%v : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %c-5 = arith.constant -5 : i64 + %0 = arith.minsi %v, %c2 : i64 + %1 = arith.maxsi %0, %c-5 : i64 + %2 = arith.subi %1, %c2 : i64 + %3 = arith.cmpi sle, %2, %c0 : i64 + %4 = arith.cmpi ule, %2, %c0 : i64 + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func @multiply_negatives +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func.func @multiply_negatives(%a : index, %b : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c_1 = arith.constant -1 : index + %c_2 = arith.constant -2 : index + %c_4 = arith.constant -4 : index + %c_12 = arith.constant -12 : index + %0 = arith.maxsi %a, %c2 : index + %1 = arith.minsi %0, %c3 : index + %2 = arith.minsi %b, %c_1 : index + %3 = arith.maxsi %2, %c_4 : index + %4 = arith.muli %1, %3 : index + %5 = arith.cmpi slt, %4, %c_12 : index + %6 = arith.cmpi slt, %c_1, %4 : index + %7 = arith.ori %5, %6 : i1 + func.return %7 : i1 +} + +// CHECK-LABEL: func @multiply_unsigned_bounds +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @multiply_unsigned_bounds(%a : i16, %b : i16) -> i1 { + %c0 = arith.constant 0 : i16 + %c4 = arith.constant 4 : i16 + %c_mask = arith.constant 0x3fff : i16 + %c_bound = arith.constant 0xfffc : i16 + %0 = arith.andi %a, %c_mask : i16 + %1 = arith.minui %b, %c4 : i16 + %2 = arith.muli %0, %1 : i16 + %3 = arith.cmpi uge, %2, %c0 : i16 + %4 = arith.cmpi ule, %2, %c_bound : i16 + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: @for_loop_with_increasing_arg +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func.func @for_loop_with_increasing_arg() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + 
%0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %c0) -> index { + %10 = arith.addi %arg0, %arg1 : index + scf.yield %10 : index + } + %1 = arith.cmpi ule, %0, %c16 : index + func.return %1 : i1 +} + +// CHECK-LABEL: @for_loop_with_constant_result +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @for_loop_with_constant_result() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %true) -> i1 { + %10 = arith.cmpi ule, %arg0, %c4 : index + %11 = arith.andi %10, %arg1 : i1 + scf.yield %11 : i1 + } + func.return %0 : i1 +} + +// Test to catch a bug present in some versions of the data flow analysis +// CHECK-LABEL: func @while_false +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: scf.condition(%[[false]]) +func.func @while_false(%arg0 : index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.divui %arg0, %c2 : index + %1 = scf.while (%arg1 = %0) : (index) -> index { + %2 = arith.cmpi slt, %arg1, %c0 : index + scf.condition(%2) %arg1 : index + } do { + ^bb0(%arg2 : index): + scf.yield %c2 : index + } + func.return %1 : index +} + +// CHECK-LABEL: func @div_bounds_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @div_bounds_positive(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %0 = arith.maxsi %arg0, %c2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sge, %1, %c0 : index + %4 = arith.cmpi sle, %1, %c2 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.cmpi sle, %1, %c2 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + %9 = arith.andi %8, %6 : i1 + func.return %9 : i1 +} + +// CHECK-LABEL: func 
@div_bounds_negative +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @div_bounds_negative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c_2 = arith.constant -2 : index + %c4 = arith.constant 4 : index + %0 = arith.minsi %arg0, %c_2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sle, %1, %c0 : index + %4 = arith.cmpi sge, %1, %c_2 : index + %5 = arith.cmpi eq, %2, %c0 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + func.return %8 : i1 +} + +// CHECK-LABEL: func @div_zero_undefined +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func.func @div_zero_undefined(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.divui %c4, %0 : index + %2 = arith.cmpi ule, %1, %c4 : index + func.return %2 : i1 +} + +// CHECK-LABEL: func @ceil_divui +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func.func @ceil_divui(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c1 : index + %2 = arith.ceildivui %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + + %4 = arith.maxui %0, %c0 : index + %5 = arith.ceildivui %4, %c4 : index + %6 = arith.cmpi eq, %5, %c1 : index + %7 = arith.andi %3, %6 : i1 + func.return %7 : i1 +} + +// CHECK-LABEL: func @ceil_divsi +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func.func @ceil_divsi(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.ceildivsi %1, %c4 : index + 
%3 = arith.cmpi eq, %2, %c1 : index + %4 = arith.ceildivsi %1, %c-4 : index + %5 = arith.cmpi eq, %4, %c0 : index + %6 = arith.andi %3, %5 : i1 + + %7 = arith.maxsi %0, %c0 : index + %8 = arith.ceildivsi %7, %c4 : index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = arith.andi %6, %9 : i1 + func.return %10 : i1 +} + +// CHECK-LABEL: func @floor_divsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @floor_divsi(%arg0 : index) -> i1 { + %c4 = arith.constant 4 : index + %c-1 = arith.constant -1 : index + %c-3 = arith.constant -3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c-1 : index + %1 = arith.maxsi %0, %c-4 : index + %2 = arith.floordivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c-1 : index + func.return %3 : i1 +} + +// CHECK-LABEL: func @remui_base +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @remui_base(%arg0 : index, %arg1 : index ) -> i1 { + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg1, %c4 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %arg0, %1 : index + %3 = arith.cmpi ult, %2, %c4 : index + func.return %3 : i1 +} + +// CHECK-LABEL: func @remsi_base +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func.func @remsi_base(%arg0 : index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %arg0, %1 : index + %3 = arith.cmpi sgt, %2, %c-4 : index + %4 = arith.cmpi slt, %2, %c4 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.andi %3, %4 : i1 + %7 = arith.andi %5, %6 : i1 + func.return %7 : i1 +} + +// CHECK-LABEL: func @remsi_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @remsi_positive(%arg0 : 
index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.maxsi %arg0, %c0 : index + %3 = arith.remsi %2, %1 : index + %4 = arith.cmpi sge, %3, %c0 : index + %5 = arith.cmpi slt, %3, %c4 : index + %6 = arith.andi %4, %5 : i1 + func.return %6 : i1 +} + +// CHECK-LABEL: func @remui_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @remui_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func @remsi_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @remsi_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func @remui_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func.func @remui_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + + %0 = arith.minui %arg0, %c5 : index + %1 = arith.maxui %0, %c3 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ne, %2, %c2 : index + func.return %3 : i1 +} + +// CHECK-LABEL: func @remsi_restricted_fails +// CHECK: 
%[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func.func @remsi_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c3 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ne, %2, %c2 : index + func.return %3 : i1 +} + +// CHECK-LABEL: func @andi +// CHECK: %[[ret:.*]] = arith.cmpi ugt +// CHECK: return %[[ret]] +func.func @andi(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %c7 = arith.constant 7 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.andi %1, %c7 : index + %3 = arith.cmpi ugt, %2, %c5 : index + %4 = arith.cmpi ule, %2, %c7 : index + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func @andi_doesnt_make_nonnegative +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func.func @andi_doesnt_make_nonnegative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = arith.addi %arg0, %c1 : index + %1 = arith.andi %arg0, %0 : index + %2 = arith.cmpi sge, %1, %c0 : index + func.return %2 : i1 +} + + +// CHECK-LABEL: func @ori +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @ori(%arg0 : i128, %arg1 : i128) -> i1 { + %c-1 = arith.constant -1 : i128 + %c0 = arith.constant 0 : i128 + + %0 = arith.minsi %arg1, %c-1 : i128 + %1 = arith.ori %arg0, %0 : i128 + %2 = arith.cmpi slt, %1, %c0 : i128 + func.return %2 : i1 +} + +// CHECK-LABEL: func @xori +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func.func @xori(%arg0 : i64, %arg1 : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c7 = arith.constant 7 : i64 + %c15 = arith.constant 15 : i64 + %true = arith.constant true + + %0 = arith.minui %arg0, %c7 : i64 + %1 = arith.minui %arg1, %c15 
: i64 + %2 = arith.xori %0, %1 : i64 + %3 = arith.cmpi sle, %2, %c15 : i64 + %4 = arith.xori %3, %true : i1 + func.return %4 : i1 +} + +// CHECK-LABEL: func @extui +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @extui(%arg0 : i16) -> i1 { + %ci16_max = arith.constant 0xffff : i32 + %0 = arith.extui %arg0 : i16 to i32 + %1 = arith.cmpi ule, %0, %ci16_max : i32 + func.return %1 : i1 +} + +// CHECK-LABEL: func @extsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @extsi(%arg0 : i16) -> i1 { + %ci16_smax = arith.constant 0x7fff : i32 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.extsi %arg0 : i16 to i32 + %1 = arith.cmpi sle, %0, %ci16_smax : i32 + %2 = arith.cmpi sge, %0, %ci16_smin : i32 + %3 = arith.andi %1, %2 : i1 + func.return %3 : i1 +} + +// CHECK-LABEL: func @trunci +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @trunci(%arg0 : i32) -> i1 { + %c-14_i32 = arith.constant -14 : i32 + %c-14_i16 = arith.constant -14 : i16 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.minsi %arg0, %c-14_i32 : i32 + %1 = arith.trunci %0 : i32 to i16 + %2 = arith.cmpi sle, %1, %c-14_i16 : i16 + %3 = arith.extsi %1 : i16 to i32 + %4 = arith.cmpi sle, %3, %c-14_i32 : i32 + %5 = arith.cmpi sge, %3, %ci16_smin : i32 + %6 = arith.andi %2, %4 : i1 + %7 = arith.andi %6, %5 : i1 + func.return %7 : i1 +} + +// CHECK-LABEL: func @index_cast +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @index_cast(%arg0 : index) -> i1 { + %ci32_smin = arith.constant 0xffffffff80000000 : i64 + %0 = arith.index_cast %arg0 : index to i32 + %1 = arith.index_cast %0 : i32 to index + %2 = arith.index_cast %ci32_smin : i64 to index + %3 = arith.cmpi sge, %1, %2 : index + func.return %3 : i1 +} + +// CHECK-LABEL: func @shli +// CHECK: %[[ret:.*]] = arith.cmpi sgt +// CHECK: return %[[ret]] +func.func @shli(%arg0 : i32, %arg1 : i1) -> i1 { + 
%c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-1 = arith.constant -1 : i32 + %c-16 = arith.constant -16 : i32 + %0 = arith.maxsi %arg0, %c-1 : i32 + %1 = arith.minsi %0, %c2 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shli %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-16 : i32 + %5 = arith.cmpi sle, %3, %c32 : i32 + %6 = arith.cmpi sgt, %3, %c8 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + func.return %8 : i1 +} + +// CHECK-LABEL: func @shrui +// CHECK: %[[ret:.*]] = arith.cmpi uge +// CHECK: return %[[ret]] +func.func @shrui(%arg0 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %0 = arith.select %arg0, %c2, %c4 : i32 + %1 = arith.shrui %c32, %0 : i32 + %2 = arith.cmpi ule, %1, %c8 : i32 + %3 = arith.cmpi uge, %1, %c2 : i32 + %4 = arith.cmpi uge, %1, %c8 : i32 + %5 = arith.andi %2, %3 : i1 + %6 = arith.andi %5, %4 : i1 + func.return %6 : i1 +} + +// CHECK-LABEL: func @shrsi +// CHECK: %[[ret:.*]] = arith.cmpi slt +// CHECK: return %[[ret]] +func.func @shrsi(%arg0 : i32, %arg1 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-8 = arith.constant -8 : i32 + %c-32 = arith.constant -32 : i32 + %0 = arith.maxsi %arg0, %c-32 : i32 + %1 = arith.minsi %0, %c32 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shrsi %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-8 : i32 + %5 = arith.cmpi sle, %3, %c8 : i32 + %6 = arith.cmpi slt, %3, %c2 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + func.return %8 : i1 +} + +// CHECK-LABEL: func @no_aggressive_eq +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func.func @no_aggressive_eq(%arg0 : index) -> i1 { + %c1 = arith.constant 1 : index + %0 = arith.andi %arg0, %c1 : index + %1 = 
arith.minui %arg0, %c1 : index + %2 = arith.cmpi eq, %0, %1 : index + func.return %2 : i1 +} + +// CHECK-LABEL: func @select_union +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] + +func.func @select_union(%arg0 : index, %arg1 : i1) -> i1 { + %c64 = arith.constant 64 : index + %c100 = arith.constant 100 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %0 = arith.remui %arg0, %c64 : index + %1 = arith.addi %0, %c128 : index + %2 = arith.select %arg1, %0, %1 : index + %3 = arith.cmpi slt, %2, %c192 : index + %4 = arith.cmpi ne, %c100, %2 : index + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func @if_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @if_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + %1 = scf.if %arg1 -> index { + %10 = arith.muli %0, %0 : index + scf.yield %10 : index + } else { + %20 = arith.muli %0, %c-1 : index + scf.yield %20 : index + } + %2 = arith.cmpi sle, %1, %c16 : index + %3 = arith.cmpi sge, %1, %c-4 : index + %4 = arith.andi %2, %3 : i1 + func.return %4 : i1 +} + +// CHECK-LABEL: func @branch_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @branch_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1 : + %1 = arith.muli %0, %0 : index + cf.br ^bb3(%1 : index) +^bb2 : + %2 = arith.muli %0, %c-1 : index + cf.br ^bb3(%2 : index) +^bb3(%3 : index) : + %4 = arith.cmpi sle, %3, %c16 : index + %5 = arith.cmpi sge, %3, %c-4 : index + %6 = arith.andi %4, %5 : i1 + func.return %6 : i1 +} + +// CHECK-LABEL: func 
@loop_bound_not_inferred_with_branch +// CHECK-DAG: %[[min:.*]] = arith.cmpi sge +// CHECK-DAG: %[[max:.*]] = arith.cmpi slt +// CHECK-DAG: %[[ret:.*]] = arith.andi %[[min]], %[[max]] +// CHECK: return %[[ret]] +func.func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.minui %arg0, %c4 : index + cf.br ^bb2(%c0 : index) +^bb1(%1 : index) : + %2 = arith.addi %1, %c1 : index + cf.br ^bb2(%2 : index) +^bb2(%3 : index): + %4 = arith.cmpi ult, %3, %c4 : index + cf.cond_br %4, ^bb1(%3 : index), ^bb3(%3 : index) +^bb3(%5 : index) : + %6 = arith.cmpi sge, %5, %c0 : index + %7 = arith.cmpi slt, %5, %c4 : index + %8 = arith.andi %6, %7 : i1 + func.return %8 : i1 +} + diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -101,16 +101,16 @@ func.return %0 : index } -// CHECK-LABEL: func @propagate_across_while_loop() -func.func @propagate_across_while_loop() -> index { +// CHECK-LABEL: func @propagate_across_while_loop_false() +func.func @propagate_across_while_loop_false() -> index { // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0 // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1 %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, smin = 0 : index, smax = 0 : index } %1 = scf.while : () -> index { - %true = arith.constant true + %false = arith.constant false // CHECK: scf.condition(%{{.*}}) %[[C0]] - scf.condition(%true) %0 : index + scf.condition(%false) %0 : index } do { ^bb0(%i1: index): scf.yield @@ -119,3 +119,42 @@ %2 = test.increment %1 return %2 : index } + +// CHECK-LABEL: func @propagate_across_while_loop +func.func 
@propagate_across_while_loop(%arg0 : i1) -> index { + // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0 + // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1 + %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index } + %1 = scf.while : () -> index { + // CHECK: scf.condition(%{{.*}}) %[[C0]] + scf.condition(%arg0) %0 : index + } do { + ^bb0(%i1: index): + scf.yield + } + // CHECK: return %[[C1]] + %2 = test.increment %1 + return %2 : index +} + +// CHECK-LABEL: func @dont_propagate_across_infinite_loop() +func.func @dont_propagate_across_infinite_loop() -> index { + // CHECK: %[[C0:.*]] = "test.constant"() {value = 0 + %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index } + // CHECK: %[[loopRes:.*]] = scf.while + %1 = scf.while : () -> index { + %true = arith.constant true + // CHECK: scf.condition(%{{.*}}) %[[C0]] + scf.condition(%true) %0 : index + } do { + ^bb0(%i1: index): + scf.yield + } + // CHECK: %[[ret:.*]] = test.reflect_bounds %[[loopRes]] + %2 = test.reflect_bounds %1 + // CHECK: return %[[ret]] + return %2 : index +} +