diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -12,6 +12,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -49,7 +50,8 @@ // Base class for integer binary operations. class Arith_IntBinaryOp traits = []> : - Arith_BinaryOp, + Arith_BinaryOp]>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, Results<(outs SignlessIntegerLike:$result)>; @@ -87,7 +89,8 @@ // Cast from an integer type to another integer type. class Arith_IToICastOp traits = []> : Arith_CastOp; + SignlessFixedWidthIntegerLike, + traits # [DeclareOpInterfaceMethods]>; // Cast from an integer type to a floating point type.
class Arith_IToFCastOp traits = []> : Arith_CastOp; @@ -104,7 +107,8 @@ class Arith_CompareOp traits = []> : Arith_Op]> { + "lhs", "result", "::getI1SameShape($_self)">, + DeclareOpInterfaceMethods]> { let results = (outs BoolLike:$result); let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; @@ -124,6 +128,7 @@ def Arith_ConstantOp : Op, + DeclareOpInterfaceMethods, TypesMatchWith< "result and attribute have the same type", "value", "result", "$_self">]> { @@ -973,7 +978,8 @@ "signless-integer-like or memref of signless-integer">; def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, - IndexCastTypeConstraint> { + IndexCastTypeConstraint, + [DeclareOpInterfaceMethods]> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -1166,7 +1172,8 @@ //===----------------------------------------------------------------------===// def SelectOp : Arith_Op<"select", [ - AllTypesMatch<["true_value", "false_value", "result"]> + AllTypesMatch<["true_value", "false_value", "result"]>, + DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ @@ -1206,7 +1213,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - + // FIXME: Switch this to use the declarative assembly format.
let hasCustomAssemblyFormat = 1; } diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRArithmetic ArithmeticOps.cpp ArithmeticDialect.cpp + InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic @@ -14,6 +15,7 @@ LINK_LIBS PUBLIC MLIRDialect + MLIRInferIntRangeInterface MLIRInferTypeOpInterface MLIRIR ) diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,802 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "int-range-analysis" + +using namespace mlir; +using namespace mlir::arith; + +/// Get the bitwidth of the attributes holding constants of type `type`. +/// Index types use IndexType::kInternalStorageBitWidth; every other type is +/// asked for getIntOrFloatBitWidth(). +/// NOTE(review): this presumably expects a scalar integer type (not a vector +/// of integers) -- confirm that callers never pass shaped types here. +static unsigned int getAttrBitwidth(Type type) { + if (type.isIndex()) + return IndexType::kInternalStorageBitWidth; + return type.getIntOrFloatBitWidth(); +} + +/// Callback through which the `inferResultRanges` implementations in this +/// file publish an inferred range, invoked as `setResultRange(result, range)`. +/// NOTE(review): the template arguments of this function_ref appear to be +/// missing from this copy of the patch; the call sites suggest a callable +/// taking (Value, const IntRangeAttrs &) -- confirm against the original. +using SetRangeFn = llvm::function_ref; + +/// Function that evaluates the result of doing something on arithmetic +/// constants and returns None on overflow.
+using ConstArithFn =
+    llvm::function_ref<Optional<APInt>(const APInt &, const APInt &)>;
+// NOTE: ConstArithFn (like DivisionFixupFn below) is a non-owning
+// llvm::function_ref. The lambdas in this file are therefore kept in `auto`
+// locals, or passed directly as call arguments, so that the callable outlives
+// every use of the reference. Initializing a ConstArithFn *variable* from a
+// lambda expression would leave it dangling as soon as that declaration ends.
+
+/// A [min, max] pair that can be signed or unsigned. An absent entry means
+/// that side of the range is unbounded.
+using IntAttrPair = std::pair<Optional<APInt>, Optional<APInt>>;
+
+/// If both `left` and `right` are defined, return the result of
+/// `op(left.getValue(), right.getValue())`, otherwise, return None.
+static Optional<APInt> compute(ConstArithFn op, const Optional<APInt> &left,
+                               const Optional<APInt> &right) {
+  if (!left || !right)
+    return {};
+  return op(left.getValue(), right.getValue());
+}
+
+/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible.
+/// If either computation overflows, make the result unbounded.
+static IntAttrPair computeBoundsBy(ConstArithFn op,
+                                   const Optional<APInt> &minLeft,
+                                   const Optional<APInt> &minRight,
+                                   const Optional<APInt> &maxLeft,
+                                   const Optional<APInt> &maxRight) {
+  Optional<APInt> min, max;
+  if (minLeft && minRight) {
+    min = compute(op, minLeft, minRight);
+    if (!min)
+      return {{}, {}};
+  }
+  if (maxLeft && maxRight) {
+    max = compute(op, maxLeft, maxRight);
+    if (!max)
+      return {{}, {}};
+  }
+  return {min, max};
+}
+
+/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`,
+/// ignoring unbounded values. Returns (null, null) if `op` overflows.
+/// This only bounds `op` when `op` is monotonic in each argument over the
+/// given ranges, so that the extremes are reached at the corners.
+static IntAttrPair minMaxBy(ConstArithFn op, ArrayRef<Optional<APInt>> lhs,
+                            ArrayRef<Optional<APInt>> rhs,
+                            bool signedCmp = false) {
+  Optional<APInt> min, max;
+  for (const Optional<APInt> &left : lhs) {
+    for (const Optional<APInt> &right : rhs) {
+      if (!left || !right) {
+        // A missing lower or upper bound should be accounted for by the parent
+        // function.
+        continue;
+      }
+      Optional<APInt> thisResult = compute(op, left, right);
+      if (!thisResult) {
+        return {{}, {}};
+      }
+      if (min) {
+        min = (signedCmp ? thisResult->slt(*min) : thisResult->ult(*min))
+                  ? thisResult
+                  : min;
+      } else {
+        min = thisResult;
+      }
+
+      if (max) {
+        max = (signedCmp ? thisResult->sgt(*max) : thisResult->ugt(*max))
+                  ? thisResult
+                  : max;
+      } else {
+        max = thisResult;
+      }
+    }
+  }
+  return {min, max};
+}
+
+/// A scalar integer constant is its own [value, value] range; anything else
+/// (e.g. non-IntegerAttr values) is left unbounded.
+void arith::ConstantOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                          SetRangeFn setResultRange) {
+  // Return null for non-scalar integer constants
+  auto constAttr = getValue().dyn_cast_or_null<IntegerAttr>();
+  Optional<APInt> value;
+  if (constAttr)
+    value = constAttr.getValue();
+  setResultRange(getResult(), IntRangeAttrs::range(value, value));
+}
+
+/// `lhs + rhs`: add the lower bounds and the upper bounds; any overflow in the
+/// respective interpretation makes that interpretation unbounded.
+void arith::AddIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  auto uadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.uadd_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  auto sadd = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.sadd_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+
+  auto urange =
+      computeBoundsBy(uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax());
+  auto srange =
+      computeBoundsBy(sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax());
+  setResultRange(getResult(), {urange, srange});
+}
+
+/// `lhs - rhs`: the lower bound is lhs.min - rhs.max and the upper bound is
+/// lhs.max - rhs.min; overflow makes that interpretation unbounded.
+void arith::SubIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  auto usub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.usub_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  auto ssub = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.ssub_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  auto urange =
+      computeBoundsBy(usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin());
+  auto srange =
+      computeBoundsBy(ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin());
+  setResultRange(getResult(), {urange, srange});
+}
+
+/// `lhs * rhs`: evaluate the product at the corners of the operand ranges,
+/// dropping any bound whose inputs are not all known.
+void arith::MulIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  // Determine what bounds we can impose on signed multiplication.
+  bool noNegatives = (lhs.smin() && rhs.smin() && lhs.smin()->isNonNegative() &&
+                      rhs.smin()->isNonNegative());
+  bool canBoundBelow =
+      (lhs.smin() && rhs.smin() && (noNegatives || (lhs.smax() && rhs.smax())));
+  bool canBoundAbove =
+      (lhs.smax() && rhs.smax() && (noNegatives || (lhs.smin() && rhs.smin())));
+
+  auto umul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.umul_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  auto smul = [](const APInt &a, const APInt &b) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = a.smul_ov(b, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+
+  Optional<APInt> umin, umax, smin, smax;
+  std::tie(umin, umax) =
+      minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*signedCmp=*/false);
+  std::tie(smin, smax) =
+      minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
+               /*signedCmp=*/true);
+
+  if (!lhs.umin() || !rhs.umin())
+    umin = {};
+  if (!lhs.umax() || !rhs.umax())
+    umax = {};
+  if (!canBoundBelow)
+    smin = {};
+  if (!canBoundAbove)
+    smax = {};
+  setResultRange(getResult(), {umin, umax, smin, smax});
+}
+
+/// Fix up division results (ex. for ceiling and floor), returning an APInt
+/// if there has been no overflow.
+using DivisionFixupFn = llvm::function_ref<Optional<APInt>(
+    const APInt &lhs, const APInt &rhs, const APInt &result)>;
+
+/// Unsigned division range. It is only bounded when the divisor's lower bound
+/// is known and non-zero; otherwise a division by zero is possible.
+static IntRangeAttrs inferDivUIRange(Type resultType, const IntRangeAttrs &lhs,
+                                     const IntRangeAttrs &rhs,
+                                     DivisionFixupFn fixup) {
+  Optional<APInt> lhsMin = lhs.umin();
+  Optional<APInt> lhsMax = lhs.umax();
+  Optional<APInt> rhsMin = rhs.umin();
+  Optional<APInt> rhsMax = rhs.umax();
+
+  unsigned int bitwidth = getAttrBitwidth(resultType);
+  if (rhsMin && !rhsMin->isZero()) {
+    if (!rhsMax) // Bound divisor above by 0xffff...fff to get lower bound of 0
+      rhsMax = APInt::getAllOnesValue(bitwidth);
+    auto udiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      return fixup(a, b, a.udiv(b));
+    };
+    auto urange = minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                           /*signedCmp=*/false);
+    if (!lhsMin)
+      urange.first = {};
+    if (!lhsMax)
+      urange.second = {};
+
+    return IntRangeAttrs::fromUnsigned(urange);
+  }
+  // Otherwise, it's possible we might divide by 0.
+  return {};
+}
+
+/// Plain unsigned division: no fixup of the truncated quotient.
+void arith::DivUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  Type resultType = getResult().getType();
+  setResultRange(getResult(),
+                 inferDivUIRange(resultType, argRanges[0], argRanges[1],
+                                 [](const APInt &, const APInt &,
+                                    const APInt &result) -> Optional<APInt> {
+                                   return result;
+                                 }));
+}
+
+/// Signed division range. It is only bounded when the divisor is known to be
+/// strictly positive or strictly negative, i.e. it can never be zero.
+static IntRangeAttrs inferDivSIRange(Type resultType, const IntRangeAttrs &lhs,
+                                     const IntRangeAttrs &rhs,
+                                     DivisionFixupFn fixup) {
+  Optional<APInt> lhsMin = lhs.smin();
+  Optional<APInt> lhsMax = lhs.smax();
+  Optional<APInt> rhsMin = rhs.smin();
+  Optional<APInt> rhsMax = rhs.smax();
+  bool canBoundBelow = rhsMin && rhsMin->isStrictlyPositive();
+  bool canBoundAbove = rhsMax && rhsMax->isNegative();
+  bool canDivide = canBoundBelow || canBoundAbove;
+
+  if (canDivide) {
+    unsigned int bitwidth = getAttrBitwidth(resultType);
+    // Unbounded below + negative upper bound -> lower bound = INT_MIN
+    if (!rhsMin)
+      rhsMin = APInt::getSignedMinValue(bitwidth);
+    // Unbounded above + positive lower bound -> upper bound = INT_MAX
+    if (!rhsMax)
+      rhsMax = APInt::getSignedMaxValue(bitwidth);
+    auto sdiv = [&fixup](const APInt &a, const APInt &b) -> Optional<APInt> {
+      bool overflowed = false;
+      APInt result = a.sdiv_ov(b, overflowed);
+      if (overflowed)
+        return {};
+      return fixup(a, b, result);
+    };
+    auto srange = minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                           /*signedCmp=*/true);
+    if (!lhsMin)
+      srange.first = {};
+    if (!lhsMax)
+      srange.second = {};
+
+    return IntRangeAttrs::fromSigned(srange);
+  }
+  return {};
+}
+
+/// Plain signed division: no fixup of the truncated quotient.
+void arith::DivSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  setResultRange(getResult(),
+                 inferDivSIRange(getResult().getType(), argRanges[0],
+                                 argRanges[1],
+                                 [](const APInt &, const APInt &,
+                                    const APInt &result) -> Optional<APInt> {
+                                   return result;
+                                 }));
+}
+
+/// Unsigned ceiling division: round the truncated quotient up when the
+/// division is inexact.
+void arith::CeilDivUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                           SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  auto ceilDivUIFix = [](const APInt &dividend, const APInt &divisor,
+                         const APInt &quotient) -> Optional<APInt> {
+    if (!dividend.urem(divisor).isZero()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.uadd_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      if (overflowed)
+        return {};
+      return corrected;
+    }
+    return quotient;
+  };
+  Type resultType = getResult().getType();
+  setResultRange(getResult(),
+                 inferDivUIRange(resultType, lhs, rhs, ceilDivUIFix));
+}
+
+/// Signed ceiling division: an inexact positive quotient (operands of equal
+/// sign) is rounded up by one.
+void arith::CeilDivSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                           SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Type resultType = getResult().getType();
+  auto ceilDivSIFix = [](const APInt &dividend, const APInt &divisor,
+                         const APInt &quotient) -> Optional<APInt> {
+    if (!dividend.srem(divisor).isZero() &&
+        dividend.isNonNegative() == divisor.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.sadd_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      if (overflowed)
+        return {};
+      return corrected;
+    }
+    return quotient;
+  };
+  setResultRange(getResult(),
+                 inferDivSIRange(resultType, lhs, rhs, ceilDivSIFix));
+}
+
+/// Signed floor division: an inexact negative quotient (operands of differing
+/// sign) is rounded down by one.
+void arith::FloorDivSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                            SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Type resultType = getResult().getType();
+  auto floorDivSIFix = [](const APInt &dividend, const APInt &divisor,
+                          const APInt &quotient) -> Optional<APInt> {
+    if (!dividend.srem(divisor).isZero() &&
+        dividend.isNonNegative() != divisor.isNonNegative()) {
+      bool overflowed = false;
+      APInt corrected =
+          quotient.ssub_ov(APInt(quotient.getBitWidth(), 1), overflowed);
+      if (overflowed)
+        return {};
+      return corrected;
+    }
+    return quotient;
+  };
+  setResultRange(getResult(),
+                 inferDivSIRange(resultType, lhs, rhs, floorDivSIFix));
+}
+
+/// Unsigned remainder: [0, maxDivisor - 1] when the divisor cannot be zero,
+/// tightened when a constant divisor sweeps a contiguous dividend range.
+void arith::RemUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  Optional<APInt> rhsMin = rhs.umin();
+  Optional<APInt> rhsMax = rhs.umax();
+  Optional<APInt> umin, umax;
+
+  if (rhsMin && rhsMax && !rhsMin->isZero()) {
+    APInt maxDivisor = *rhsMax;
+    umin = APInt::getZero(maxDivisor.getBitWidth());
+    umax = maxDivisor - 1;
+    // Special case: sweeping out a contiguous range in N/[modulus]
+    Optional<APInt> lhsMin = lhs.umin();
+    Optional<APInt> lhsMax = lhs.umax();
+    if (lhsMin && lhsMax && rhsMin == rhsMax) {
+      APInt minDividend = *lhsMin;
+      APInt maxDividend = *lhsMax;
+      if ((maxDividend - minDividend).ult(maxDivisor)) {
+        APInt minRem = minDividend.urem(maxDivisor);
+        APInt maxRem = maxDividend.urem(maxDivisor);
+        if (minRem.ule(maxRem)) {
+          umin = minRem;
+          umax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), IntRangeAttrs::fromUnsigned(umin, umax));
+}
+
+/// Signed remainder: bounded by +/-(|maxDivisor| - 1) when the divisor cannot
+/// be zero, with the sign following the possible signs of the dividend.
+void arith::RemSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  Optional<APInt> lhsMin = lhs.smin();
+  Optional<APInt> lhsMax = lhs.smax();
+  Optional<APInt> rhsMin = rhs.smin();
+  Optional<APInt> rhsMax = rhs.smax();
+
+  Optional<APInt> smin, smax;
+  // No bounds if zero could be a divisor.
+  bool canBound = rhsMax && rhsMin &&
+                  (rhsMin->isStrictlyPositive() || rhsMax->isNegative());
+  if (canBound) {
+    APInt maxDivisor = rhsMin->isStrictlyPositive() ? *rhsMax : rhsMin->abs();
+    bool canNegativeDividend = !(lhsMin && lhsMin->isNonNegative());
+    bool canPositiveDividend = !(lhsMax && lhsMax->isNonPositive());
+    APInt zero = APInt::getZero(maxDivisor.getBitWidth());
+    APInt maxPositiveResult = maxDivisor - 1;
+    APInt minNegativeResult = -maxPositiveResult;
+    smin = canNegativeDividend ? minNegativeResult : zero;
+    smax = canPositiveDividend ? maxPositiveResult : zero;
+    // Special case: sweeping out a contiguous range in N/[modulus]
+    if (lhsMin && lhsMax && rhsMin == rhsMax) {
+      APInt minDividend = *lhsMin;
+      APInt maxDividend = *lhsMax;
+      if ((maxDividend - minDividend).ult(maxDivisor)) {
+        APInt minRem = minDividend.srem(maxDivisor);
+        APInt maxRem = maxDividend.srem(maxDivisor);
+        if (minRem.sle(maxRem)) {
+          smin = minRem;
+          smax = maxRem;
+        }
+      }
+    }
+  }
+  setResultRange(getResult(), IntRangeAttrs::fromSigned(smin, smax));
+}
+
+/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???,
+/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits
+/// that both bounds have in common. This gives us a conservative approximation
+/// for what values can be passed to bitwise operations. This will widen missing
+/// bounds to all zeroes / all ones so we can handle [unbounded] & 0xff => [0,
+/// 255]. Both entries of the returned pair are always set.
+static IntAttrPair widenBitwiseBounds(Type resultType,
+                                      const IntRangeAttrs &bound) {
+  unsigned int bitwidth = getAttrBitwidth(resultType);
+  APInt leftVal = bound.umin().getValueOr(APInt::getZero(bitwidth));
+  APInt rightVal = bound.umax().getValueOr(APInt::getAllOnesValue(bitwidth));
+  unsigned int differingBits =
+      bitwidth - (leftVal ^ rightVal).countLeadingZeros();
+  leftVal.clearLowBits(differingBits);
+  rightVal.setLowBits(differingBits);
+  return {leftVal, rightVal};
+}
+
+/// `lhs & rhs` is monotonic in each operand bit, so the widened corners bound
+/// it.
+void arith::AndIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  Type resultType = getResult().getType();
+  Optional<APInt> lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(resultType, argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(resultType, argRanges[1]);
+  setResultRange(
+      getResult(),
+      IntRangeAttrs::fromUnsigned(minMaxBy(
+          [](const APInt &a, const APInt &b) -> Optional<APInt> {
+            return a & b;
+          },
+          {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false)));
+}
+
+/// `lhs | rhs` is monotonic in each operand bit, so the widened corners bound
+/// it.
+void arith::OrIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                     SetRangeFn setResultRange) {
+  Type resultType = getResult().getType();
+  Optional<APInt> lhsZeros, lhsOnes, rhsZeros, rhsOnes;
+  std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(resultType, argRanges[0]);
+  std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(resultType, argRanges[1]);
+  setResultRange(
+      getResult(),
+      IntRangeAttrs::fromUnsigned(minMaxBy(
+          [](const APInt &a, const APInt &b) -> Optional<APInt> {
+            return a | b;
+          },
+          {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false)));
+}
+
+/// `lhs ^ rhs`. XOR is not monotonic, so evaluating it at the widened corners
+/// does not bound it (lhs in [0, 3], rhs == 1: the corners give [1, 2], but
+/// 1 ^ 1 == 0). Instead, every bit that may vary in either widened operand may
+/// vary in the result, and the remaining bits are the XOR of the common
+/// prefixes.
+void arith::XOrIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  Type resultType = getResult().getType();
+  IntAttrPair lhsBits = widenBitwiseBounds(resultType, argRanges[0]);
+  IntAttrPair rhsBits = widenBitwiseBounds(resultType, argRanges[1]);
+  // widenBitwiseBounds always sets both bounds, and they differ exactly in a
+  // run of low bits, so min ^ max is the mask of bits that may vary.
+  APInt freeBits = (*lhsBits.first ^ *lhsBits.second) |
+                   (*rhsBits.first ^ *rhsBits.second);
+  APInt fixedBits = (*lhsBits.first ^ *rhsBits.first) & ~freeBits;
+  setResultRange(getResult(),
+                 IntRangeAttrs::fromUnsigned(fixedBits, fixedBits | freeBits));
+}
+
+/// Signed max: lower bound is the larger known lower bound; upper bound needs
+/// both upper bounds.
+void arith::MaxSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Optional<APInt> smin;
+  // Take the largest lower bound (if any).
+  if (lhs.smin() && rhs.smin())
+    smin = lhs.smin()->sgt(*rhs.smin()) ? lhs.smin() : rhs.smin();
+  else if (lhs.smin())
+    smin = lhs.smin();
+  else if (rhs.smin())
+    smin = rhs.smin();
+
+  // If both upper bounds are present, take their max, be unbounded otherwise.
+  Optional<APInt> smax;
+  if (lhs.smax() && rhs.smax())
+    smax = lhs.smax()->sgt(*rhs.smax()) ? lhs.smax() : rhs.smax();
+  setResultRange(getResult(), IntRangeAttrs::fromSigned(smin, smax));
+}
+
+/// Unsigned max: lower bound is the larger known lower bound; upper bound
+/// needs both upper bounds.
+void arith::MaxUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Optional<APInt> umin;
+  // Take the largest lower bound (if any).
+  if (lhs.umin() && rhs.umin())
+    umin = lhs.umin()->ugt(*rhs.umin()) ? lhs.umin() : rhs.umin();
+  else if (lhs.umin())
+    umin = lhs.umin();
+  else if (rhs.umin())
+    umin = rhs.umin();
+
+  // If both upper bounds are present, take their max, be unbounded otherwise.
+  Optional<APInt> umax;
+  if (lhs.umax() && rhs.umax())
+    umax = lhs.umax()->ugt(*rhs.umax()) ? lhs.umax() : rhs.umax();
+  setResultRange(getResult(), IntRangeAttrs::fromUnsigned(umin, umax));
+}
+
+/// Signed min: lower bound needs both lower bounds; upper bound is the
+/// smaller known upper bound.
+void arith::MinSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Optional<APInt> smin;
+  // If both lower bounds are present, take their minimum.
+  if (lhs.smin() && rhs.smin())
+    smin = lhs.smin()->slt(*rhs.smin()) ? lhs.smin() : rhs.smin();
+
+  // For upper bounds, take the smallest (with absent -> +infinity).
+  Optional<APInt> smax;
+  if (lhs.smax() && rhs.smax())
+    smax = lhs.smax()->slt(*rhs.smax()) ? lhs.smax() : rhs.smax();
+  else if (lhs.smax())
+    smax = lhs.smax();
+  else if (rhs.smax())
+    smax = rhs.smax();
+  setResultRange(getResult(), IntRangeAttrs::fromSigned(smin, smax));
+}
+
+/// Unsigned min: lower bound needs both lower bounds; upper bound is the
+/// smaller known upper bound.
+void arith::MinUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Optional<APInt> umin;
+  // If both lower bounds are present, take their minimum.
+  if (lhs.umin() && rhs.umin())
+    umin = lhs.umin()->ult(*rhs.umin()) ? lhs.umin() : rhs.umin();
+
+  // For upper bounds, take the smallest (with absent -> +infinity).
+  Optional<APInt> umax;
+  if (lhs.umax() && rhs.umax())
+    umax = lhs.umax()->ult(*rhs.umax()) ? lhs.umax() : rhs.umax();
+  else if (lhs.umax())
+    umax = lhs.umax();
+  else if (rhs.umax())
+    umax = rhs.umax();
+  setResultRange(getResult(), IntRangeAttrs::fromUnsigned(umin, umax));
+}
+
+/// Zero extension: missing bounds default to the full source range, then both
+/// bounds are zero-extended.
+void arith::ExtUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  unsigned int srcWidth = sourceType.getIntOrFloatBitWidth();
+  unsigned int destWidth = destType.getIntOrFloatBitWidth();
+  APInt umin = argRanges[0].umin().getValueOr(APInt::getZero(srcWidth));
+  APInt umax = argRanges[0].umax().getValueOr(APInt::getAllOnesValue(srcWidth));
+
+  umin = umin.zext(destWidth);
+  umax = umax.zext(destWidth);
+  setResultRange(getResult(), IntRangeAttrs::fromUnsigned(umin, umax));
+}
+
+/// Sign extension: missing bounds default to the full signed source range,
+/// then both bounds are sign-extended.
+static IntRangeAttrs extSIRange(const IntRangeAttrs &range, Type sourceType,
+                                Type destType) {
+  unsigned int srcWidth = getAttrBitwidth(sourceType);
+  unsigned int destWidth = getAttrBitwidth(destType);
+  APInt smin = range.smin().getValueOr(APInt::getSignedMinValue(srcWidth));
+  APInt smax = range.smax().getValueOr(APInt::getSignedMaxValue(srcWidth));
+
+  smin = smin.sext(destWidth);
+  smax = smax.sext(destWidth);
+  return IntRangeAttrs::fromSigned(smin, smax);
+}
+
+void arith::ExtSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  setResultRange(getResult(), extSIRange(argRanges[0], sourceType, destType));
+}
+
+/// Infer the range of truncating values in `range` to the width of
+/// `destType`. Truncating each bound on its own is only valid when no value in
+/// the range wraps around in the narrower type. For example, [255, 257] : i16
+/// truncates to {255, 0, 1} : i8, not to [255, 1]. When wrap-around is
+/// possible, that interpretation is left unbounded.
+static IntRangeAttrs truncIRange(const IntRangeAttrs &range, Type destType) {
+  unsigned int destWidth = getAttrBitwidth(destType);
+  Optional<APInt> umin, umax, smin, smax;
+
+  // Unsigned: safe when both bounds agree on every discarded high bit.
+  if (range.umin() && range.umax() &&
+      range.umin()->lshr(destWidth) == range.umax()->lshr(destWidth)) {
+    umin = range.umin()->trunc(destWidth);
+    umax = range.umax()->trunc(destWidth);
+  }
+
+  // Signed: safe when both bounds agree on the discarded bits plus the new
+  // sign bit, or when the whole range already fits in the narrower type
+  // (high parts all ones for the minimum and all zeros for the maximum).
+  if (range.smin() && range.smax()) {
+    APInt sminHigh = range.smin()->ashr(destWidth - 1);
+    APInt smaxHigh = range.smax()->ashr(destWidth - 1);
+    if (sminHigh == smaxHigh || (sminHigh.isAllOnes() && smaxHigh.isZero())) {
+      smin = range.smin()->trunc(destWidth);
+      smax = range.smax()->trunc(destWidth);
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
+void arith::TruncIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                        SetRangeFn setResultRange) {
+  Type destType = getResult().getType();
+  setResultRange(getResult(), truncIRange(argRanges[0], destType));
+}
+
+/// index_cast sign-extends when widening, truncates when narrowing, and
+/// passes the range through unchanged when the widths match.
+void arith::IndexCastOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                           SetRangeFn setResultRange) {
+  Type sourceType = getOperand().getType();
+  Type destType = getResult().getType();
+  unsigned int srcWidth = getAttrBitwidth(sourceType);
+  unsigned int destWidth = getAttrBitwidth(destType);
+
+  if (srcWidth < destWidth)
+    setResultRange(getResult(), extSIRange(argRanges[0], sourceType, destType));
+  else if (srcWidth > destWidth)
+    setResultRange(getResult(), truncIRange(argRanges[0], destType));
+  else
+    setResultRange(getResult(), argRanges[0]);
+}
+
+/// Return true if `pred(lhs, rhs)` holds for every pair of values in the
+/// given ranges. Returns false when that cannot be proven.
+static bool isStaticallyTrue(arith::CmpIPredicate pred,
+                             const IntRangeAttrs &lhs,
+                             const IntRangeAttrs &rhs) {
+  switch (pred) {
+  case arith::CmpIPredicate::sle:
+  case arith::CmpIPredicate::slt:
+    return (lhs.smax() && rhs.smin() &&
+            applyCmpPredicate(pred, *lhs.smax(), *rhs.smin()));
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult:
+    return (lhs.umax() && rhs.umin() &&
+            applyCmpPredicate(pred, *lhs.umax(), *rhs.umin()));
+  case arith::CmpIPredicate::sge:
+  case arith::CmpIPredicate::sgt:
+    return (lhs.smin() && rhs.smax() &&
+            applyCmpPredicate(pred, *lhs.smin(), *rhs.smax()));
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+    return (lhs.umin() && rhs.umax() &&
+            applyCmpPredicate(pred, *lhs.umin(), *rhs.umax()));
+  case arith::CmpIPredicate::eq: {
+    Optional<APInt> lhsConst = lhs.getConstantValue();
+    Optional<APInt> rhsConst = rhs.getConstantValue();
+    return lhsConst && rhsConst && lhsConst == rhsConst;
+  }
+  case arith::CmpIPredicate::ne: {
+    // While equality requires that there is an interpretation of the preceding
+    // computations that produces equal constants, whether that be signed or
+    // unsigned, statically determining inequality requires that neither
+    // interpretation produce potentially overlapping ranges.
+    bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs);
+    bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) ||
+               isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs);
+    return sne && une;
+  }
+  }
+  return false;
+}
+
+/// cmpi folds to a constant i1 range when the predicate or its inverse is
+/// statically true; otherwise the result is unbounded.
+void arith::CmpIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  arith::CmpIPredicate pred = getPredicate();
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+
+  Optional<APInt> value;
+  if (isStaticallyTrue(pred, lhs, rhs))
+    value = APInt::getAllOnesValue(1);
+  else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs))
+    value = APInt::getZero(1);
+
+  setResultRange(getResult(), IntRangeAttrs::range(value, value));
+}
+
+void arith::CmpFOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  // Can't infer anything about floats.
+  setResultRange(getResult(), {});
+}
+
+/// select: with a known condition, forward the chosen operand's range;
+/// otherwise join both operand ranges.
+void arith::SelectOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                        SetRangeFn setResultRange) {
+  Optional<APInt> mbCondVal = argRanges[0].getConstantValue();
+
+  if (mbCondVal) {
+    if (mbCondVal->isZero())
+      setResultRange(getResult(), argRanges[2]);
+    else
+      setResultRange(getResult(), argRanges[1]);
+    return;
+  }
+  setResultRange(getResult(), IntRangeAttrs::join(argRanges[1], argRanges[2]));
+}
+
+/// Left shift. Shifting by `r` behaves like multiplying by 2^r, so it is
+/// monotonic (and bounded by the corners) only when no corner overflows.
+/// ushl_ov/sshl_ov also report shift amounts >= the bitwidth as overflow.
+/// A plain `shl` would silently wrap: [64, 192] << 1 on i8 would give
+/// [128, 128] although 100 << 1 == 200.
+void arith::ShLIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                      SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  if (!rhs.umin() || !rhs.umax()) {
+    setResultRange(getResult(), {});
+    return;
+  }
+
+  auto ushl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = l.ushl_ov(r, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  auto sshl = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    bool overflowed = false;
+    APInt result = l.sshl_ov(r, overflowed);
+    if (overflowed)
+      return {};
+    return result;
+  };
+  Optional<APInt> umin, umax, smin, smax;
+  std::tie(umin, umax) =
+      minMaxBy(ushl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*signedCmp=*/false);
+  std::tie(smin, smax) =
+      minMaxBy(sshl, {lhs.smin(), lhs.smax()}, {rhs.umin(), rhs.umax()},
+               /*signedCmp=*/true);
+  if (!lhs.umin())
+    umin = {};
+  if (!lhs.umax())
+    umax = {};
+  if (!lhs.smin())
+    smin = {};
+  if (!lhs.smax())
+    smax = {};
+
+  setResultRange(getResult(), {umin, umax, smin, smax});
+}
+
+/// Logical right shift: monotonic, so the corners bound it. Shift amounts that
+/// may reach the bitwidth make the result unbounded.
+void arith::ShRUIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  if (!rhs.umin() || !rhs.umax()) {
+    setResultRange(getResult(), {});
+    return;
+  }
+
+  auto lshr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return {};
+    return l.lshr(r);
+  };
+  Optional<APInt> umin, umax;
+  std::tie(umin, umax) =
+      minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()},
+               /*signedCmp=*/false);
+  if (!lhs.umin())
+    umin = {};
+  if (!lhs.umax())
+    umax = {};
+
+  setResultRange(getResult(), IntRangeAttrs::fromUnsigned(umin, umax));
+}
+
+/// Arithmetic right shift: the signed corners bound it. Shift amounts that may
+/// reach the bitwidth make the result unbounded.
+void arith::ShRSIOp::inferResultRanges(ArrayRef<IntRangeAttrs> argRanges,
+                                       SetRangeFn setResultRange) {
+  const IntRangeAttrs &lhs = argRanges[0];
+  const IntRangeAttrs &rhs = argRanges[1];
+  if (!rhs.umin() || !rhs.umax()) {
+    setResultRange(getResult(), {});
+    return;
+  }
+
+  auto ashr = [](const APInt &l, const APInt &r) -> Optional<APInt> {
+    if (r.uge(r.getBitWidth()))
+      return {};
+    return l.ashr(r);
+  };
+  Optional<APInt> smin, smax;
+  std::tie(smin, smax) = minMaxBy(ashr, {lhs.smin(), lhs.smax()},
+                                  {rhs.umin(), rhs.umax()}, /*signedCmp=*/true);
+  if (!lhs.smin())
+    smin = {};
+  if (!lhs.smax())
+    smax = {};
+
+  setResultRange(getResult(), IntRangeAttrs::fromSigned(smin, smax));
+}
diff --git a/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir
@@ -0,0 +1,626 @@
+// RUN: mlir-opt -arith-fold-inferred-constants -canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: func @add_min_max
+// CHECK: %[[c3:.*]] = arith.constant 3 : index
+// CHECK: return %[[c3]]
+func
@add_min_max(%a: index, %b: index) -> index { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.minsi %a, %c1 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.minui %b, %c2 : index + %3 = arith.maxui %2, %c2 : index + %4 = arith.addi %1, %3 : index + return %4 : index +} + +// CHECK-LABEL: func @add_lower_bound +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @add_lower_bound(%a : i32, %b : i32) -> i1 { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %0 = arith.maxsi %a, %c1 : i32 + %1 = arith.maxsi %b, %c1 : i32 + %2 = arith.addi %0, %1 : i32 + %3 = arith.cmpi sge, %2, %c2 : i32 + %4 = arith.cmpi uge, %2, %c2 : i32 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @sub_signed_vs_unsigned +// CHECK-NOT: arith.cmpi sle +// CHECK: %[[unsigned:.*]] = arith.cmpi ule +// CHECK: return %[[unsigned]] : i1 +func @sub_signed_vs_unsigned(%v : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %0 = arith.minsi %v, %c2 : i64 + %1 = arith.subi %0, %c2 : i64 + %2 = arith.cmpi sle, %1, %c0 : i64 + %3 = arith.cmpi ule, %1, %c0 : i64 + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @multiply_negatives +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @multiply_negatives(%a : index, %b : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c_1 = arith.constant -1 : index + %c_2 = arith.constant -2 : index + %c_4 = arith.constant -4 : index + %c_12 = arith.constant -12 : index + %0 = arith.maxsi %a, %c2 : index + %1 = arith.minsi %0, %c3 : index + %2 = arith.minsi %b, %c_1 : index + %3 = arith.maxsi %2, %c_4 : index + %4 = arith.muli %1, %3 : index + %5 = arith.cmpi slt, %4, %c_12 : index + %6 = arith.cmpi slt, %c_1, %4 : index + %7 = arith.ori %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @multiply_unsigned_bounds +// CHECK: %[[true:.*]] = arith.constant 
true +// CHECK: return %[[true]] +func @multiply_unsigned_bounds(%a : i16, %b : i16) -> i1 { + %c0 = arith.constant 0 : i16 + %c4 = arith.constant 4 : i16 + %c_mask = arith.constant 0x3fff : i16 + %c_bound = arith.constant 0xfffc : i16 + %0 = arith.andi %a, %c_mask : i16 + %1 = arith.minui %b, %c4 : i16 + %2 = arith.muli %0, %1 : i16 + %3 = arith.cmpi uge, %2, %c0 : i16 + %4 = arith.cmpi ule, %2, %c_bound : i16 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: @for_loop_with_increasing_arg +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @for_loop_with_increasing_arg() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %c0) -> index { + %10 = arith.addi %arg0, %arg1 : index + scf.yield %10 : index + } + %1 = arith.cmpi ule, %0, %c16 : index + return %1 : i1 +} + +// CHECK-LABEL: @for_loop_with_constant_result +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @for_loop_with_constant_result() -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %true) -> i1 { + %10 = arith.cmpi ule, %arg0, %c4 : index + %11 = arith.andi %10, %arg1 : i1 + scf.yield %11 : i1 + } + return %0 : i1 +} + +// CHECK-LABEL: func @div_bounds_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_positive(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %0 = arith.maxsi %arg0, %c2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sge, %1, %c0 : index + %4 = arith.cmpi sle, %1, %c2 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.cmpi sle, %1, %c2 : index 
+ + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + %9 = arith.andi %8, %6 : i1 + return %9 : i1 +} + +// CHECK-LABEL: func @div_bounds_negative +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_negative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c_2 = arith.constant -2 : index + %c4 = arith.constant 4 : index + %0 = arith.minsi %arg0, %c_2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sle, %1, %c0 : index + %4 = arith.cmpi sge, %1, %c_2 : index + %5 = arith.cmpi eq, %2, %c0 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @div_zero_undefined +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @div_zero_undefined(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.divui %c4, %0 : index + %2 = arith.cmpi ule, %1, %c4 : index + return %2 : i1 +} + +// CHECK-LABEL: func @ceil_divui +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divui(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c1 : index + %2 = arith.ceildivui %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + + %4 = arith.maxui %0, %c0 : index + %5 = arith.ceildivui %4, %c4 : index + %6 = arith.cmpi eq, %5, %c1 : index + %7 = arith.andi %3, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @ceil_divsi +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divsi(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + + %0 = 
arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.ceildivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + %4 = arith.ceildivsi %1, %c-4 : index + %5 = arith.cmpi eq, %4, %c0 : index + %6 = arith.andi %3, %5 : i1 + + %7 = arith.maxsi %0, %c0 : index + %8 = arith.ceildivsi %7, %c4 : index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = arith.andi %6, %9 : i1 + return %10 : i1 +} + +// CHECK-LABEL: func @floor_divsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @floor_divsi(%arg0 : index) -> i1 { + %c4 = arith.constant 4 : index + %c-1 = arith.constant -1 : index + %c-3 = arith.constant -3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c-1 : index + %1 = arith.maxsi %0, %c-4 : index + %2 = arith.floordivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c-1 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remui_base +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_base(%arg0 : index, %arg1 : index ) -> i1 { + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg1, %c4 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %arg0, %1 : index + %3 = arith.cmpi ult, %2, %c4 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remsi_base +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @remsi_base(%arg0 : index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %arg0, %1 : index + %3 = arith.cmpi sgt, %2, %c-4 : index + %4 = arith.cmpi slt, %2, %c4 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.andi %3, %4 : i1 + %7 = arith.andi %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @remsi_positive +// CHECK: %[[true:.*]] = 
arith.constant true +// CHECK: return %[[true]] +func @remsi_positive(%arg0 : index, %arg1 : index ) -> i1 { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.maxsi %arg0, %c0 : index + %3 = arith.remsi %2, %1 : index + %4 = arith.cmpi sge, %3, %c0 : index + %5 = arith.cmpi slt, %3, %c4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @remui_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remsi_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remsi_restricted(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remui_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remui_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + + %0 = arith.minui %arg0, %c5 : index + %1 = arith.maxui %0, %c3 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return %3 : i1 +} + +// 
CHECK-LABEL: func @remsi_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remsi_restricted_fails(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c3 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @andi +// CHECK: %[[ret:.*]] = arith.cmpi ugt +// CHECK: return %[[ret]] +func @andi(%arg0 : index) -> i1 { + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %c7 = arith.constant 7 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.andi %1, %c7 : index + %3 = arith.cmpi ugt, %2, %c5 : index + %4 = arith.cmpi ule, %2, %c7 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @andi_doesnt_make_nonnegative +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @andi_doesnt_make_nonnegative(%arg0 : index) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = arith.addi %arg0, %c1 : index + %1 = arith.andi %arg0, %0 : index + %2 = arith.cmpi sge, %1, %c0 : index + return %2 : i1 +} + + +// CHECK-LABEL: func @ori +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @ori(%arg0 : i128, %arg1 : i128) -> i1 { + %c-1 = arith.constant -1 : i128 + %c0 = arith.constant 0 : i128 + + %0 = arith.minsi %arg1, %c-1 : i128 + %1 = arith.ori %arg0, %0 : i128 + %2 = arith.cmpi slt, %1, %c0 : i128 + return %2 : i1 +} + +// CHECK-LABEL: func @xori +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @xori(%arg0 : i64, %arg1 : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c7 = arith.constant 7 : i64 + %c15 = arith.constant 15 : i64 + %true = arith.constant true + + %0 = arith.minui %arg0, %c7 : i64 + %1 = arith.minui 
%arg1, %c15 : i64 + %2 = arith.xori %0, %1 : i64 + %3 = arith.cmpi sle, %2, %c15 : i64 + %4 = arith.xori %3, %true : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @extui +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extui(%arg0 : i16) -> i1 { + %ci16_max = arith.constant 0xffff : i32 + %0 = arith.extui %arg0 : i16 to i32 + %1 = arith.cmpi ule, %0, %ci16_max : i32 + return %1 : i1 +} + +// CHECK-LABEL: func @extsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extsi(%arg0 : i16) -> i1 { + %ci16_smax = arith.constant 0x7fff : i32 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.extsi %arg0 : i16 to i32 + %1 = arith.cmpi sle, %0, %ci16_smax : i32 + %2 = arith.cmpi sge, %0, %ci16_smin : i32 + %3 = arith.andi %1, %2 : i1 + return %3 : i1 +} + +// CHECK-LABEL: func @trunci +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @trunci(%arg0 : i32) -> i1 { + %c-14_i32 = arith.constant -14 : i32 + %c-14_i16 = arith.constant -14 : i16 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.minsi %arg0, %c-14_i32 : i32 + %1 = arith.trunci %0 : i32 to i16 + %2 = arith.cmpi sle, %1, %c-14_i16 : i16 + %3 = arith.extsi %1 : i16 to i32 + %4 = arith.cmpi sle, %3, %c-14_i32 : i32 + %5 = arith.cmpi sge, %3, %ci16_smin : i32 + %6 = arith.andi %2, %4 : i1 + %7 = arith.andi %6, %5 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @index_cast +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @index_cast(%arg0 : index) -> i1 { + %ci32_smin = arith.constant 0xffffffff80000000 : i64 + %0 = arith.index_cast %arg0 : index to i32 + %1 = arith.index_cast %0 : i32 to index + %2 = arith.index_cast %ci32_smin : i64 to index + %3 = arith.cmpi sge, %1, %2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @shli +// CHECK: %[[ret:.*]] = arith.cmpi sgt +// CHECK: return %[[ret]] +func @shli(%arg0 : i32, %arg1 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = 
arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-1 = arith.constant -1 : i32 + %c-16 = arith.constant -16 : i32 + %0 = arith.maxsi %arg0, %c-1 : i32 + %1 = arith.minsi %0, %c2 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shli %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-16 : i32 + %5 = arith.cmpi sle, %3, %c32 : i32 + %6 = arith.cmpi sgt, %3, %c8 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @shrui +// CHECK: %[[ret:.*]] = arith.cmpi uge +// CHECK: return %[[ret]] +func @shrui(%arg0 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %0 = arith.select %arg0, %c2, %c4 : i32 + %1 = arith.shrui %c32, %0 : i32 + %2 = arith.cmpi ule, %1, %c8 : i32 + %3 = arith.cmpi uge, %1, %c2 : i32 + %4 = arith.cmpi uge, %1, %c8 : i32 + %5 = arith.andi %2, %3 : i1 + %6 = arith.andi %5, %4 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @shrsi +// CHECK: %[[ret:.*]] = arith.cmpi slt +// CHECK: return %[[ret]] +func @shrsi(%arg0 : i32, %arg1 : i1) -> i1 { + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-8 = arith.constant -8 : i32 + %c-32 = arith.constant -32 : i32 + %0 = arith.maxsi %arg0, %c-32 : i32 + %1 = arith.minsi %0, %c32 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shrsi %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-8 : i32 + %5 = arith.cmpi sle, %3, %c8 : i32 + %6 = arith.cmpi slt, %3, %c2 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @no_aggressive_eq +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @no_aggressive_eq(%arg0 : index) -> i1 { + %c1 = arith.constant 1 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.minui %arg0, %c1 : index + %2 = arith.cmpi eq, %0, %1 : index + 
return %2 : i1 +} + +// CHECK-LABEL: func @select_union +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] + +func @select_union(%arg0 : index, %arg1 : i1) -> i1 { + %c64 = arith.constant 64 : index + %c100 = arith.constant 100 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %0 = arith.remui %arg0, %c64 : index + %1 = arith.addi %0, %c128 : index + %2 = arith.select %arg1, %0, %1 : index + %3 = arith.cmpi slt, %2, %c192 : index + %4 = arith.cmpi ne, %c100, %2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @if_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @if_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + %1 = scf.if %arg1 -> index { + %10 = arith.muli %0, %0 : index + scf.yield %10 : index + } else { + %20 = arith.muli %0, %c-1 : index + scf.yield %20 : index + } + %2 = arith.cmpi sle, %1, %c16 : index + %3 = arith.cmpi sge, %1, %c-4 : index + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @branch_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @branch_union(%arg0 : index, %arg1 : i1) -> i1 { + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1 : + %1 = arith.muli %0, %0 : index + cf.br ^bb3(%1 : index) +^bb2 : + %2 = arith.muli %0, %c-1 : index + cf.br ^bb3(%2 : index) +^bb3(%3 : index) : + %4 = arith.cmpi sle, %3, %c16 : index + %5 = arith.cmpi sge, %3, %c-4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @loop_bound_not_inferred_with_branch +// CHECK-DAG: %[[min:.*]] = arith.cmpi sge +// CHECK-DAG: %[[max:.*]] = 
arith.cmpi slt +// CHECK-DAG: %[[ret:.*]] = arith.andi %[[min]], %[[max]] +// CHECK: return %[[ret]] +func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.minui %arg0, %c4 : index + cf.br ^bb2(%c0 : index) +^bb1(%1 : index) : + %2 = arith.addi %1, %c1 : index + cf.br ^bb2(%2 : index) +^bb2(%3 : index): + %4 = arith.cmpi ult, %3, %c4 : index + cf.cond_br %4, ^bb1(%3 : index), ^bb3(%3 : index) +^bb3(%5 : index) : + %6 = arith.cmpi sge, %5, %c0 : index + %7 = arith.cmpi slt, %5, %c4 : index + %8 = arith.andi %6, %7 : i1 + return %8 : i1 +} +