diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -10,6 +10,7 @@ #define ARITHMETIC_OPS include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" +include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -49,7 +50,8 @@ // Base class for integer binary operations. class Arith_IntBinaryOp traits = []> : - Arith_BinaryOp, + Arith_BinaryOp]>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, Results<(outs SignlessIntegerLike:$result)>; @@ -87,7 +89,8 @@ // Cast from an integer type to another integer type. class Arith_IToICastOp traits = []> : Arith_CastOp; + SignlessFixedWidthIntegerLike, + traits # [DeclareOpInterfaceMethods]>; // Cast from an integer type to a floating point type.
class Arith_IToFCastOp traits = []> : Arith_CastOp; @@ -104,7 +107,8 @@ class Arith_CompareOp traits = []> : Arith_Op]> { + "lhs", "result", "::getI1SameShape($_self)">, + DeclareOpInterfaceMethods]> { let results = (outs BoolLike:$result); let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; @@ -124,6 +128,7 @@ def Arith_ConstantOp : Op, + DeclareOpInterfaceMethods, TypesMatchWith< "result and attribute have the same type", "value", "result", "$_self">]> { @@ -973,7 +978,8 @@ "signless-integer-like or memref of signless-integer">; def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, - IndexCastTypeConstraint> { + IndexCastTypeConstraint, + [DeclareOpInterfaceMethods]> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -1166,7 +1172,8 @@ //===----------------------------------------------------------------------===// def SelectOp : Arith_Op<"select", [ - AllTypesMatch<["true_value", "false_value", "result"]> + AllTypesMatch<["true_value", "false_value", "result"]>, + DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ @@ -1206,7 +1213,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - + // FIXME: Switch this to use the declarative assembly format.
let hasCustomAssemblyFormat = 1; } diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -6,6 +6,7 @@ ArithmeticOps.cpp ArithmeticDialect.cpp InferIntRangeInterface.cpp + InferIntRangeInterfaceImpls.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp @@ -0,0 +1,871 @@ +//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arith" + +using namespace mlir; +using namespace mlir::arith; + +/// Get the bitwidth of the APInt held by an IntegerAttr of type `type`: +/// `index` attributes store IndexType::kInternalStorageBitWidth bits, all +/// other types use their own integer/float bitwidth. +static unsigned int getAttrBitwidth(Type type) { + if (type.isIndex()) + return IndexType::kInternalStorageBitWidth; + return type.getIntOrFloatBitWidth(); +} + +/// Function that applies a binary operation to two arithmetic constants and +/// returns the result, or None if the operation overflows.
+using ConstArithFn = + llvm::function_ref(const APInt &, const APInt &)>; + +/// A [min, max] pair that can be signed or unsigned +using IntAttrPair = std::pair; + +/// If both `left` and `right` are defined, return the result of +/// `op(left.getValue(), right.getValue()`, where None is converted +/// to a null IntegerAttr. Otherwise, return the null attribute. +static IntegerAttr compute(ConstArithFn op, IntegerAttr left, + IntegerAttr right) { + if (!left || !right) + return {}; + assert(left.getType() == right.getType() && + "Arithmetic ops don't have mismatched operands"); + llvm::Optional result = op(left.getValue(), right.getValue()); + if (!result.hasValue()) + return {}; + return IntegerAttr::get(left.getType(), *result); +} + +/// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, +/// If either computation overflows, make the result unbounded. +static IntAttrPair computeBoundsBy(ConstArithFn op, IntegerAttr minLeft, + IntegerAttr minRight, IntegerAttr maxLeft, + IntegerAttr maxRight) { + IntegerAttr min, max; + if (minLeft && minRight) { + min = compute(op, minLeft, minRight); + if (!min) + return {{}, {}}; + } + if (maxLeft && maxRight) { + max = compute(op, maxLeft, maxRight); + if (!max) + return {{}, {}}; + } + return {min, max}; +} + +/// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +/// ignoring unbounded values. Returns (null, null) if `op` overflows. +static IntAttrPair minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool signedCmp = false) { + IntegerAttr min, max; + for (IntegerAttr left : lhs) { + for (IntegerAttr right : rhs) { + if (!left || !right) { + // A missing lower or upper bound should be accounted for by the parent + // function + continue; + } + IntegerAttr thisResult = compute(op, left, right); + if (!thisResult) { + return {{}, {}}; + } + APInt thisValue = thisResult.getValue(); + if (min) { + min = (signedCmp ? 
thisValue.slt(min.getValue()) + : thisValue.ult(min.getValue())) + ? thisResult + : min; + } else { + min = thisResult; + } + + if (max) { + max = (signedCmp ? thisValue.sgt(max.getValue()) + : thisValue.ugt(max.getValue())) + ? thisResult + : max; + } else { + max = thisResult; + } + } + } + return {min, max}; +} + +void arith::ConstantOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + // Return null for non-scalar integer constants + auto value = getValue().dyn_cast_or_null(); + resultRanges.push_back(IntRangeAttrs::range(value, value)); +} + +void arith::AddIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + ConstArithFn uadd = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.uadd_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + ConstArithFn sadd = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + + auto urange = + computeBoundsBy(uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax()); + auto srange = + computeBoundsBy(sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax()); + resultRanges.emplace_back(urange, srange); +} + +void arith::SubIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + ConstArithFn usub = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.usub_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + ConstArithFn ssub = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + auto urange = + computeBoundsBy(usub, lhs.umin(), rhs.umax(), lhs.umax(), 
rhs.umin()); + auto srange = + computeBoundsBy(ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin()); + resultRanges.emplace_back(urange, srange); +} + +void arith::MulIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + // Determine what bounds we can impose on signed multiplication. + bool noNegatives = + (lhs.smin() && rhs.smin() && lhs.smin().getValue().isNonNegative() && + rhs.smin().getValue().isNonNegative()); + bool canBoundBelow = + (lhs.smin() && rhs.smin() && (noNegatives || (lhs.smax() && rhs.smax()))); + bool canBoundAbove = + (lhs.smax() && rhs.smax() && (noNegatives || (lhs.smin() && rhs.smin()))); + + ConstArithFn umul = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.umul_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + ConstArithFn smul = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.smul_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + + IntegerAttr umin, umax, smin, smax; + std::tie(umin, umax) = + minMaxBy(umul, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*signedCmp=*/false); + std::tie(smin, smax) = + minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()}, + /*signedCmp=*/true); + + if (!lhs.umin() || !rhs.umin()) + umin = {}; + if (!lhs.umax() || !rhs.umax()) + umax = {}; + if (!canBoundBelow) + smin = {}; + if (!canBoundAbove) + smax = {}; + resultRanges.emplace_back(umin, umax, smin, smax); +} + +/// Fix up division results (ex. 
for ceiling and floor), returning an APInt +/// if there has been no overflow +using DivisionFixupFn = llvm::function_ref( + const APInt &lhs, const APInt &rhs, const APInt &result)>; + +static IntRangeAttrs inferDivUIRange(Type resultType, const IntRangeAttrs &lhs, + const IntRangeAttrs &rhs, + DivisionFixupFn fixup) { + IntegerAttr lhsMin = lhs.umin(); + IntegerAttr lhsMax = lhs.umax(); + IntegerAttr rhsMin = rhs.umin(); + IntegerAttr rhsMax = rhs.umax(); + + unsigned int bitwidth = getAttrBitwidth(resultType); + if (rhsMin && !rhsMin.getValue().isZero()) { + if (!rhsMax) // Bound divisor above by 0xffff...fff to get lower bound of 0 + rhsMax = IntegerAttr::get(resultType, APInt::getAllOnesValue(bitwidth)); + auto udiv = [&fixup](auto &a, auto &b) -> llvm::Optional { + return fixup(a, b, a.udiv(b)); + }; + auto urange = minMaxBy(udiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*signedCmp=*/false); + if (!lhsMin) + urange.first = {}; + if (!lhsMax) + urange.second = {}; + + return IntRangeAttrs::fromUnsigned(urange); + } + // Otherwise, it's possible we might divide by 0. 
+ return {}; +} + +void arith::DivUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + resultRanges.push_back(inferDivUIRange( + resultType, argRanges[0], argRanges[1], + [](auto &lhs, auto &rhs, auto &result) { return result; })); +} + +static IntRangeAttrs inferDivSIRange(Type resultType, const IntRangeAttrs &lhs, + const IntRangeAttrs &rhs, + DivisionFixupFn fixup) { + IntegerAttr lhsMin = lhs.smin(); + IntegerAttr lhsMax = lhs.smax(); + IntegerAttr rhsMin = rhs.smin(); + IntegerAttr rhsMax = rhs.smax(); + bool canBoundBelow = rhsMin && rhsMin.getValue().isStrictlyPositive(); + bool canBoundAbove = rhsMax && rhsMax.getValue().isNegative(); + bool canDivide = canBoundBelow || canBoundAbove; + + if (canDivide) { + unsigned int bitwidth = getAttrBitwidth(resultType); + // Unbounded below + negative upper bound -> lower bound = INT_MIN + if (!rhsMin) + rhsMin = IntegerAttr::get(resultType, APInt::getSignedMinValue(bitwidth)); + // Unbounded above + positive lower bound -> upper bound = INT_MAX + if (!rhsMax) + rhsMax = IntegerAttr::get(resultType, APInt::getSignedMaxValue(bitwidth)); + auto sdiv = [&fixup](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + if (overflowed) + return {}; + return fixup(a, b, result); + }; + auto srange = minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*signedCmp=*/true); + if (!lhsMin) + srange.first = {}; + if (!lhsMax) + srange.second = {}; + + return IntRangeAttrs::fromSigned(srange); + } + return {}; +} + +void arith::DivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + resultRanges.push_back(inferDivSIRange( + getResult().getType(), argRanges[0], argRanges[1], + [](auto &lhs, auto &rhs, auto &result) { return result; })); +} + +void arith::CeilDivUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = 
argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + DivisionFixupFn ceilDivUIFix = [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> Optional { + if (!lhs.urem(rhs).isZero()) { + bool overflowed = false; + APInt corrected = + result.uadd_ov(APInt(result.getBitWidth(), 1), overflowed); + if (overflowed) + return {}; + return corrected; + } + return result; + }; + Type resultType = getResult().getType(); + resultRanges.push_back(inferDivUIRange(resultType, lhs, rhs, ceilDivUIFix)); +} + +void arith::CeilDivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + Type resultType = getResult().getType(); + DivisionFixupFn ceilDivSIFix = [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> Optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() == rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.sadd_ov(APInt(result.getBitWidth(), 1), overflowed); + if (overflowed) + return {}; + return corrected; + } + return result; + }; + resultRanges.push_back(inferDivSIRange(resultType, lhs, rhs, ceilDivSIFix)); +} + +void arith::FloorDivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + Type resultType = getResult().getType(); + DivisionFixupFn floorDivSIFix = [](const APInt &lhs, const APInt &rhs, + const APInt &result) -> Optional { + if (!lhs.srem(rhs).isZero() && lhs.isNonNegative() != rhs.isNonNegative()) { + bool overflowed = false; + APInt corrected = + result.ssub_ov(APInt(result.getBitWidth(), 1), overflowed); + if (overflowed) + return {}; + return corrected; + } + return result; + }; + resultRanges.push_back(inferDivSIRange(resultType, lhs, rhs, floorDivSIFix)); +} + +void arith::RemUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + 
const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + IntegerAttr rhsMin = rhs.umin(); + IntegerAttr rhsMax = rhs.umax(); + IntegerAttr umin, umax; + Type resultType = getResult().getType(); + + if (rhsMin && rhsMax && !rhsMin.getValue().isZero()) { + APInt maxDivisor = rhsMax.getValue(); + umin = + IntegerAttr::get(resultType, APInt::getZero(maxDivisor.getBitWidth())); + umax = IntegerAttr::get(resultType, maxDivisor - 1); + // Special case: sweeping out a contiguous range in N/[modulus] + IntegerAttr lhsMin = lhs.umin(); + IntegerAttr lhsMax = lhs.umax(); + if (lhsMin && lhsMax && rhsMin == rhsMax) { + APInt minDividend = lhsMin.getValue(); + APInt maxDividend = lhsMax.getValue(); + if ((maxDividend - minDividend).ult(maxDivisor)) { + APInt minRem = minDividend.urem(maxDivisor); + APInt maxRem = maxDividend.urem(maxDivisor); + if (minRem.ule(maxRem)) { + umin = IntegerAttr::get(resultType, minRem); + umax = IntegerAttr::get(resultType, maxRem); + } + } + } + } + resultRanges.push_back(IntRangeAttrs::fromUnsigned(umin, umax)); +} + +void arith::RemSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + IntegerAttr lhsMin = lhs.smin(); + IntegerAttr lhsMax = lhs.smax(); + IntegerAttr rhsMin = rhs.smin(); + IntegerAttr rhsMax = rhs.smax(); + + Type resultType = getResult().getType(); + IntegerAttr smin, smax; + // No bounds if zero could be a divisor. + bool canBound = rhsMax && rhsMin && + (rhsMin.getValue().isStrictlyPositive() || + rhsMax.getValue().isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.getValue().isStrictlyPositive() + ? 
rhsMax.getValue() + : rhsMin.getValue().abs(); + bool canNegativeDividend = !(lhsMin && lhsMin.getValue().isNonNegative()); + bool canPositiveDividend = !(lhsMax && lhsMax.getValue().isNonPositive()); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + smin = IntegerAttr::get(resultType, + canNegativeDividend ? minNegativeResult : zero); + smax = IntegerAttr::get(resultType, + canPositiveDividend ? maxPositiveResult : zero); + // Special case: sweeping out a contiguous range in N/[modulus] + if (lhsMin && lhsMax && rhsMin == rhsMax) { + APInt minDividend = lhsMin.getValue(); + APInt maxDividend = lhsMax.getValue(); + if ((maxDividend - minDividend).ult(maxDivisor)) { + APInt minRem = minDividend.srem(maxDivisor); + APInt maxRem = maxDividend.srem(maxDivisor); + if (minRem.sle(maxRem)) { + smin = IntegerAttr::get(resultType, minRem); + smax = IntegerAttr::get(resultType, maxRem); + } + } + } + } + resultRanges.push_back(IntRangeAttrs::fromSigned(smin, smax)); +} + +/// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +/// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +/// that both bonuds have in common. This gives us a consertive approximation +/// for what values can be passed to bitwise operations. This will widen missing +/// bounds to all zeroes / all ones so we can handle [unbounded] & 0xff => [0, +/// 255]. +static IntAttrPair widenBitwiseBounds(Type resultType, + const IntRangeAttrs &bound) { + unsigned int bitwidth = getAttrBitwidth(resultType); + APInt leftVal = + bound.umin() ? bound.umin().getValue() : APInt::getZero(bitwidth); + APInt rightVal = + bound.umax() ? 
bound.umax().getValue() : APInt::getAllOnesValue(bitwidth); + unsigned int differingBits = + bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + IntegerAttr zeroes = IntegerAttr::get(resultType, leftVal); + IntegerAttr ones = IntegerAttr::get(resultType, rightVal); + return {zeroes, ones}; +} + +void arith::AndIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(resultType, argRanges[0]); + std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(resultType, argRanges[1]); + resultRanges.push_back(IntRangeAttrs::fromUnsigned( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a & b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false))); +} + +void arith::OrIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(resultType, argRanges[0]); + std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(resultType, argRanges[1]); + resultRanges.push_back(IntRangeAttrs::fromUnsigned( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a | b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false))); +} + +void arith::XOrIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = widenBitwiseBounds(resultType, argRanges[0]); + std::tie(rhsZeros, rhsOnes) = widenBitwiseBounds(resultType, argRanges[1]); + resultRanges.push_back(IntRangeAttrs::fromUnsigned( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a ^ b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, 
/*signedCmp=*/false))); +} + +void arith::MaxSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + IntegerAttr smin; + // Take the largest lower bound (if any). + if (lhs.smin() && rhs.smin()) + smin = lhs.smin().getValue().sgt(rhs.smin().getValue()) ? lhs.smin() + : rhs.smin(); + else if (lhs.smin()) + smin = lhs.smin(); + else if (rhs.smin()) + smin = rhs.smin(); + + // If both upper bounds are present, take their max, be unbounded otherwise. + IntegerAttr smax; + if (lhs.smax() && rhs.smax()) + smax = lhs.smax().getValue().sgt(rhs.smax().getValue()) ? lhs.smax() + : rhs.smax(); + resultRanges.push_back(IntRangeAttrs::fromSigned(smin, smax)); +} + +void arith::MaxUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + IntegerAttr umin; + // Take the largest lower bound (if any). + if (lhs.umin() && rhs.umin()) + umin = lhs.umin().getValue().ugt(rhs.umin().getValue()) ? lhs.umin() + : rhs.umin(); + else if (lhs.umin()) + umin = lhs.umin(); + else if (rhs.umin()) + umin = rhs.umin(); + + // If both upper bounds are present, take their max, be unbounded otherwise. + IntegerAttr umax; + if (lhs.umax() && rhs.umax()) + umax = lhs.umax().getValue().ugt(rhs.umax().getValue()) ? lhs.umax() + : rhs.umax(); + resultRanges.push_back(IntRangeAttrs::fromUnsigned(umin, umax)); +} + +void arith::MinSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + IntegerAttr smin; + // If both lower bounds are present, take their minimum. + if (lhs.smin() && rhs.smin()) + smin = lhs.smin().getValue().slt(rhs.smin().getValue()) ? lhs.smin() + : rhs.smin(); + + // For upper bounds, take the smallest (with absent -> +infinity). 
+ IntegerAttr smax; + if (lhs.smax() && rhs.smax()) + smax = lhs.smax().getValue().slt(rhs.smax().getValue()) ? lhs.smax() + : rhs.smax(); + else if (lhs.smax()) + smax = lhs.smax(); + else if (rhs.smax()) + smax = rhs.smax(); + resultRanges.push_back(IntRangeAttrs::fromSigned(smin, smax)); +} + +void arith::MinUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + IntegerAttr umin; + // If both lower bounds are present, take their minimum. + if (lhs.umin() && rhs.umin()) + umin = lhs.umin().getValue().ult(rhs.umin().getValue()) ? lhs.umin() + : rhs.umin(); + + // For upper bounds, take the smallest (with absent -> +infinity). + IntegerAttr umax; + if (lhs.umax() && rhs.umax()) + umax = lhs.umax().getValue().ult(rhs.umax().getValue()) ? lhs.umax() + : rhs.umax(); + else if (lhs.umax()) + umax = lhs.umax(); + else if (rhs.umax()) + umax = rhs.umax(); + resultRanges.push_back(IntRangeAttrs::fromUnsigned(umin, umax)); +} + +void arith::ExtUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + unsigned int srcWidth = sourceType.getIntOrFloatBitWidth(); + unsigned int destWidth = destType.getIntOrFloatBitWidth(); + IntegerAttr umin = argRanges[0].umin(); + IntegerAttr umax = argRanges[0].umax(); + + if (umin) + umin = IntegerAttr::get(destType, umin.getValue().zext(destWidth)); + else + umin = IntegerAttr::get(destType, APInt::getZero(destWidth)); + + if (umax) + umax = IntegerAttr::get(destType, umax.getValue().zext(destWidth)); + else + umax = + IntegerAttr::get(destType, APInt::getLowBitsSet(destWidth, srcWidth)); + + resultRanges.push_back(IntRangeAttrs::fromUnsigned(umin, umax)); +} + +static IntRangeAttrs extSIRange(const IntRangeAttrs &range, Type sourceType, + Type destType) { + unsigned int srcWidth = getAttrBitwidth(sourceType); + 
unsigned int destWidth = getAttrBitwidth(destType); + IntegerAttr smin = range.smin(); + IntegerAttr smax = range.smax(); + if (smin) + smin = IntegerAttr::get(destType, smin.getValue().sext(destWidth)); + else + smin = IntegerAttr::get( + destType, APInt::getHighBitsSet(destWidth, destWidth - srcWidth + 1)); + + if (smax) + smax = IntegerAttr::get(destType, smax.getValue().sext(destWidth)); + else + smax = IntegerAttr::get(destType, + APInt::getLowBitsSet(destWidth, srcWidth - 1)); + + return IntRangeAttrs::fromSigned(smin, smax); +} + +void arith::ExtSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + resultRanges.push_back(extSIRange(argRanges[0], sourceType, destType)); +} + +static IntRangeAttrs truncIRange(const IntRangeAttrs &range, Type destType) { + IntegerAttr umin = range.umin(); + IntegerAttr umax = range.umax(); + IntegerAttr smin = range.smin(); + IntegerAttr smax = range.smax(); + unsigned int destWidth = getAttrBitwidth(destType); + + if (umin) + umin = IntegerAttr::get(destType, umin.getValue().trunc(destWidth)); + if (umax) + umax = IntegerAttr::get(destType, umax.getValue().trunc(destWidth)); + if (smin) + smin = IntegerAttr::get(destType, smin.getValue().trunc(destWidth)); + if (smax) + smax = IntegerAttr::get(destType, smax.getValue().trunc(destWidth)); + + return {umin, umax, smin, smax}; +} + +void arith::TruncIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type destType = getResult().getType(); + resultRanges.push_back(truncIRange(argRanges[0], destType)); +} + +void arith::IndexCastOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + unsigned int srcWidth = getAttrBitwidth(sourceType); + unsigned int destWidth = getAttrBitwidth(destType); + + if (srcWidth < destWidth) + 
resultRanges.push_back(extSIRange(argRanges[0], sourceType, destType)); + else if (srcWidth > destWidth) + resultRanges.push_back(truncIRange(argRanges[0], destType)); + else + resultRanges.push_back(argRanges[0]); +} + +bool isStaticallyTrue(arith::CmpIPredicate pred, const IntRangeAttrs &lhs, + const IntRangeAttrs &rhs) { + switch (pred) { + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: + return ( + lhs.smax() && rhs.smin() && + applyCmpPredicate(pred, lhs.smax().getValue(), rhs.smin().getValue())); + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + return ( + lhs.umax() && rhs.umin() && + applyCmpPredicate(pred, lhs.umax().getValue(), rhs.umin().getValue())); + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: + return ( + lhs.smin() && rhs.smax() && + applyCmpPredicate(pred, lhs.smin().getValue(), rhs.smax().getValue())); + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + return ( + lhs.umin() && rhs.umax() && + applyCmpPredicate(pred, lhs.umin().getValue(), rhs.umax().getValue())); + case arith::CmpIPredicate::eq: { + Optional lhsConst = lhs.getConstantValue(); + Optional rhsConst = rhs.getConstantValue(); + return lhsConst.hasValue() && rhsConst.hasValue() && + (lhsConst->getValue() == rhsConst->getValue()); + } + case arith::CmpIPredicate::ne: { + // While equality requires that there is an interpration of the preceeding + // computations that produces equal constants, whether that be signed or + // unsigned, statically determining inequality requires that neither + // interpretation produce potentially overlapping ranges. 
+ bool sne = isStaticallyTrue(CmpIPredicate::slt, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::sgt, lhs, rhs); + bool une = isStaticallyTrue(CmpIPredicate::ult, lhs, rhs) || + isStaticallyTrue(CmpIPredicate::ugt, lhs, rhs); + return sne && une; + } + } + return false; +} + +void arith::CmpIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + arith::CmpIPredicate pred = getPredicate(); + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + + IntegerAttr value; + if (isStaticallyTrue(pred, lhs, rhs)) + value = IntegerAttr::get(getResult().getType(), 1); + else if (isStaticallyTrue(invertPredicate(pred), lhs, rhs)) + value = IntegerAttr::get(getResult().getType(), 0); + + resultRanges.push_back(IntRangeAttrs::range(value, value)); +} + +void arith::CmpFOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + // Can't infer anything about floats. + resultRanges.emplace_back(); +} + +void arith::SelectOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Optional mbCondVal = argRanges[0].getConstantValue(); + + if (mbCondVal.hasValue()) { + if (mbCondVal->getValue().isZero()) + resultRanges.push_back(argRanges[2]); + else + resultRanges.push_back(argRanges[1]); + return; + } + resultRanges.push_back(IntRangeAttrs::join(argRanges[1], argRanges[2])); +} + +void arith::ShLIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + if (!rhs.umin() || !rhs.umax()) { + resultRanges.emplace_back(); + return; + } + + ConstArithFn shl = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.shl(r); + }; + IntegerAttr umin, umax, smin, smax; + std::tie(umin, umax) = + minMaxBy(shl, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*signedCmp=*/false); + std::tie(smin, smax) = + minMaxBy(shl, {lhs.smin(), lhs.smax()}, 
{rhs.umin(), rhs.umax()}, + /*signedCmp=*/true); + if (!lhs.umin()) + umin = {}; + if (!lhs.umax()) + umax = {}; + if (!lhs.smin()) + smin = {}; + if (!lhs.smax()) + smax = {}; + + resultRanges.emplace_back(umin, umax, smin, smax); +} + +void arith::ShRUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + if (!rhs.umin() || !rhs.umax()) { + resultRanges.emplace_back(); + return; + } + + ConstArithFn lshr = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.lshr(r); + }; + IntegerAttr umin, umax; + std::tie(umin, umax) = + minMaxBy(lshr, {lhs.umin(), lhs.umax()}, {rhs.umin(), rhs.umax()}, + /*signedCmp=*/false); + if (!lhs.umin()) + umin = {}; + if (!lhs.umax()) + umax = {}; + + resultRanges.push_back(IntRangeAttrs::fromUnsigned(umin, umax)); +} + +void arith::ShRSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + const IntRangeAttrs &lhs = argRanges[0]; + const IntRangeAttrs &rhs = argRanges[1]; + if (!rhs.umin() || !rhs.umax()) { + resultRanges.emplace_back(); + return; + } + + ConstArithFn ashr = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.ashr(r); + }; + IntegerAttr smin, smax; + std::tie(smin, smax) = minMaxBy(ashr, {lhs.smin(), lhs.smax()}, + {rhs.umin(), rhs.umax()}, /*signedCmp=*/true); + if (!lhs.smin()) + smin = {}; + if (!lhs.smax()) + smax = {}; + + resultRanges.push_back(IntRangeAttrs::fromSigned(smin, smax)); +} diff --git a/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Arithmetic/fold-inferred-constants.mlir @@ -0,0 +1,626 @@ +// RUN: mlir-opt -arith-fold-inferred-constants -canonicalize %s | FileCheck %s + +// CHECK-LABEL: func @add_min_max +// CHECK: %[[c3:.*]] = arith.constant 3 : index +// 
CHECK: return %[[c3]] +func @add_min_max(%a: index, %b: index) -> index { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = arith.minsi %a, %c1 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.minui %b, %c2 : index + %3 = arith.maxui %2, %c2 : index + %4 = arith.addi %1, %3 : index + return %4 : index +} + +// CHECK-LABEL: func @add_lower_bound +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @add_lower_bound(%a : i32, %b : i32) -> i1 { + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %0 = arith.maxsi %a, %c1 : i32 + %1 = arith.maxsi %b, %c1 : i32 + %2 = arith.addi %0, %1 : i32 + %3 = arith.cmpi sge, %2, %c2 : i32 + %4 = arith.cmpi uge, %2, %c2 : i32 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @sub_signed_vs_unsigned +// CHECK-NOT: arith.cmpi sle +// CHECK: %[[unsigned:.*]] = arith.cmpi ule +// CHECK: return %[[unsigned]] : i1 +func @sub_signed_vs_unsigned(%v : i64) -> i1 { + %c0 = arith.constant 0 : i64 + %c2 = arith.constant 2 : i64 + %0 = arith.minsi %v, %c2 : i64 + %1 = arith.subi %0, %c2 : i64 + %2 = arith.cmpi sle, %1, %c0 : i64 + %3 = arith.cmpi ule, %1, %c0 : i64 + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @multiply_negatives +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @multiply_negatives(%a : index, %b : index) -> i1 { + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c_1 = arith.constant -1 : index + %c_2 = arith.constant -2 : index + %c_4 = arith.constant -4 : index + %c_12 = arith.constant -12 : index + %0 = arith.maxsi %a, %c2 : index + %1 = arith.minsi %0, %c3 : index + %2 = arith.minsi %b, %c_1 : index + %3 = arith.maxsi %2, %c_4 : index + %4 = arith.muli %1, %3 : index + %5 = arith.cmpi slt, %4, %c_12 : index + %6 = arith.cmpi slt, %c_1, %4 : index + %7 = arith.ori %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @multiply_unsigned_bounds +// CHECK: 
%[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @multiply_unsigned_bounds(%a : i16, %b : i16) -> i1 { // %0 <= 0x3fff and %1 <= 4, so the unsigned product stays within [0, 0xfffc]. + %c0 = arith.constant 0 : i16 + %c4 = arith.constant 4 : i16 + %c_mask = arith.constant 0x3fff : i16 + %c_bound = arith.constant 0xfffc : i16 + %0 = arith.andi %a, %c_mask : i16 + %1 = arith.minui %b, %c4 : i16 + %2 = arith.muli %0, %1 : i16 + %3 = arith.cmpi uge, %2, %c0 : i16 + %4 = arith.cmpi ule, %2, %c_bound : i16 + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: @for_loop_with_increasing_arg +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @for_loop_with_increasing_arg() -> i1 { // The carried sum ends at 0+1+2+3 = 6, but the ule compare is expected NOT to fold. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %c0) -> index { + %10 = arith.addi %arg0, %arg1 : index + scf.yield %10 : index + } + %1 = arith.cmpi ule, %0, %c16 : index + return %1 : i1 +} + +// CHECK-LABEL: @for_loop_with_constant_result +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @for_loop_with_constant_result() -> i1 { // The induction variable stays below 4, so every ule compare is true and the carried i1 folds to true. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %true) -> i1 { + %10 = arith.cmpi ule, %arg0, %c4 : index + %11 = arith.andi %10, %arg1 : i1 + scf.yield %11 : i1 + } + return %0 : i1 +} + +// CHECK-LABEL: func @div_bounds_positive +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_positive(%arg0 : index) -> i1 { // %0 >= 2, so both 4 / %0 quotients (signed %1 and unsigned %2) lie in [0, 2]. + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %0 = arith.maxsi %arg0, %c2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sge, %1, %c0 : index + %4 = arith.cmpi sle, %1, %c2 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = 
arith.cmpi sle, %2, %c2 : index // upper bound of the divui quotient %2, mirroring %4 for the divsi quotient %1 + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + %9 = arith.andi %8, %6 : i1 + return %9 : i1 +} + +// CHECK-LABEL: func @div_bounds_negative +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @div_bounds_negative(%arg0 : index) -> i1 { // %0 <= -2: the signed quotient lies in [-2, 0], and as an unsigned value %0 exceeds 4, so divui gives 0. + %c0 = arith.constant 0 : index + %c_2 = arith.constant -2 : index + %c4 = arith.constant 4 : index + %0 = arith.minsi %arg0, %c_2 : index + %1 = arith.divsi %c4, %0 : index + %2 = arith.divui %c4, %0 : index + + %3 = arith.cmpi sle, %1, %c0 : index + %4 = arith.cmpi sge, %1, %c_2 : index + %5 = arith.cmpi eq, %2, %c0 : index + + %7 = arith.andi %3, %4 : i1 + %8 = arith.andi %7, %5 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @div_zero_undefined +// CHECK: %[[ret:.*]] = arith.cmpi ule +// CHECK: return %[[ret]] +func @div_zero_undefined(%arg0 : index) -> i1 { // %0 may be 0 (division by zero), so the ule compare is expected to stay. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.divui %c4, %0 : index + %2 = arith.cmpi ule, %1, %c4 : index + return %2 : i1 +} + +// CHECK-LABEL: func @ceil_divui +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divui(%arg0 : index) -> i1 { // %1 is in [1, 3], so ceildivui by 4 is exactly 1; %4 may be 0, so only the second eq compare survives. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c1 : index + %2 = arith.ceildivui %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + + %4 = arith.maxui %0, %c0 : index + %5 = arith.ceildivui %4, %c4 : index + %6 = arith.cmpi eq, %5, %c1 : index + %7 = arith.andi %3, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @ceil_divsi +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @ceil_divsi(%arg0 : index) -> i1 { // %1 is in [1, 3]: ceildivsi by 4 is 1 and by -4 is 0; %7 may be 0, so one eq compare survives. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c-4 = 
arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c1 : index + %2 = arith.ceildivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c1 : index + %4 = arith.ceildivsi %1, %c-4 : index + %5 = arith.cmpi eq, %4, %c0 : index + %6 = arith.andi %3, %5 : i1 + + %7 = arith.maxsi %0, %c0 : index + %8 = arith.ceildivsi %7, %c4 : index + %9 = arith.cmpi eq, %8, %c1 : index + %10 = arith.andi %6, %9 : i1 + return %10 : i1 +} + +// CHECK-LABEL: func @floor_divsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @floor_divsi(%arg0 : index) -> i1 { // %1 is in [-4, -1], so floordivsi by 4 is exactly -1. + %c4 = arith.constant 4 : index + %c-1 = arith.constant -1 : index + %c-3 = arith.constant -3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c-1 : index + %1 = arith.maxsi %0, %c-4 : index + %2 = arith.floordivsi %1, %c4 : index + %3 = arith.cmpi eq, %2, %c-1 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remui_base +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_base(%arg0 : index, %arg1 : index ) -> i1 { // The divisor %1 is in [2, 4], so the remainder is at most 3 (< 4). + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg1, %c4 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %arg0, %1 : index + %3 = arith.cmpi ult, %2, %c4 : index + return %3 : i1 +} + +// CHECK-LABEL: func @remsi_base +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @remsi_base(%arg0 : index, %arg1 : index ) -> i1 { // The divisor %1 is in [2, 4], so the remainder lies in (-4, 4); %arg0 may be negative, so the sge compare stays. + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %c-4 = arith.constant -4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %arg0, %1 : index + %3 = arith.cmpi sgt, %2, %c-4 : index + %4 = arith.cmpi slt, %2, %c4 : index + %5 = arith.cmpi sge, %2, %c0 : index + %6 = arith.andi %3, %4 : i1 + %7 = arith.andi %5, %6 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @remsi_positive 
+// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remsi_positive(%arg0 : index, %arg1 : index ) -> i1 { // Both operands are non-negative and the divisor is at most 4, so the remainder lies in [0, 4). + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %true = arith.constant true + + %0 = arith.minsi %arg1, %c4 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.maxsi %arg0, %c0 : index + %3 = arith.remsi %2, %1 : index + %4 = arith.cmpi sge, %3, %c0 : index + %5 = arith.cmpi slt, %3, %c4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @remui_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remui_restricted(%arg0 : index) -> i1 { // %1 is in [2, 3], below the divisor 4, so remui leaves the range [2, 3] unchanged. + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + + %0 = arith.minui %arg0, %c3 : index + %1 = arith.maxui %0, %c2 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remsi_restricted +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @remsi_restricted(%arg0 : index) -> i1 { // %1 is in [2, 3] and |-4| exceeds it, so remsi leaves the range [2, 3] unchanged. + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c3 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ule, %2, %c3 : index + %4 = arith.cmpi uge, %2, %c2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @remui_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remui_restricted_fails(%arg0 : index) -> i1 { // %1 is in [3, 5]: remainders 3, 0, 1 only fit the interval [0, 3], which contains 2, so ne stays. + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c5 = arith.constant 5 : index + + %0 = arith.minui %arg0, %c5 : index + %1 = arith.maxui %0, %c3 : index + %2 = arith.remui %1, %c4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return 
%3 : i1 +} + +// CHECK-LABEL: func @remsi_restricted_fails +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] +func @remsi_restricted_fails(%arg0 : index) -> i1 { // %1 is in [3, 5]: remainders by -4 are 3, 0, 1, whose enclosing interval contains 2, so ne stays. + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c5 = arith.constant 5 : index + %c-4 = arith.constant -4 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c3 : index + %2 = arith.remsi %1, %c-4 : index + %3 = arith.cmpi ne, %2, %c2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @andi +// CHECK: %[[ret:.*]] = arith.cmpi ugt +// CHECK: return %[[ret]] +func @andi(%arg0 : index) -> i1 { // %1 is in [2, 5] (non-negative), so %1 & 7 is at most 7 (ule folds) while ugt 5 stays. + %c2 = arith.constant 2 : index + %c5 = arith.constant 5 : index + %c7 = arith.constant 7 : index + + %0 = arith.minsi %arg0, %c5 : index + %1 = arith.maxsi %0, %c2 : index + %2 = arith.andi %1, %c7 : index + %3 = arith.cmpi ugt, %2, %c5 : index + %4 = arith.cmpi ule, %2, %c7 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @andi_doesnt_make_nonnegative +// CHECK: %[[ret:.*]] = arith.cmpi sge +// CHECK: return %[[ret]] +func @andi_doesnt_make_nonnegative(%arg0 : index) -> i1 { // %arg0 & (%arg0 + 1) can still be negative (e.g. %arg0 = -2 gives -2), so sge must stay. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = arith.addi %arg0, %c1 : index + %1 = arith.andi %arg0, %0 : index + %2 = arith.cmpi sge, %1, %c0 : index + return %2 : i1 +} + + +// CHECK-LABEL: func @ori +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @ori(%arg0 : i128, %arg1 : i128) -> i1 { // %0 <= -1 has its sign bit set, so the ori result is always negative. + %c-1 = arith.constant -1 : i128 + %c0 = arith.constant 0 : i128 + + %0 = arith.minsi %arg1, %c-1 : i128 + %1 = arith.ori %arg0, %0 : i128 + %2 = arith.cmpi slt, %1, %c0 : i128 + return %2 : i1 +} + +// CHECK-LABEL: func @xori +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func @xori(%arg0 : i64, %arg1 : i64) -> i1 { // %0 <= 7 and %1 <= 15, so the xor is at most 15; sle folds to true and its negation %4 to false. + %c0 = arith.constant 0 : i64 + %c7 = arith.constant 7 : i64 + %c15 = arith.constant 15 : i64 + %true = arith.constant true + + %0 = arith.minui %arg0, %c7 : i64 + %1 = 
arith.minui %arg1, %c15 : i64 + %2 = arith.xori %0, %1 : i64 + %3 = arith.cmpi sle, %2, %c15 : i64 + %4 = arith.xori %3, %true : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @extui +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extui(%arg0 : i16) -> i1 { // A zero-extended i16 never exceeds 0xffff. + %ci16_max = arith.constant 0xffff : i32 + %0 = arith.extui %arg0 : i16 to i32 + %1 = arith.cmpi ule, %0, %ci16_max : i32 + return %1 : i1 +} + +// CHECK-LABEL: func @extsi +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @extsi(%arg0 : i16) -> i1 { // A sign-extended i16 stays within [-32768, 32767]. + %ci16_smax = arith.constant 0x7fff : i32 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.extsi %arg0 : i16 to i32 + %1 = arith.cmpi sle, %0, %ci16_smax : i32 + %2 = arith.cmpi sge, %0, %ci16_smin : i32 + %3 = arith.andi %1, %2 : i1 + return %3 : i1 +} + +// CHECK-LABEL: func @trunci +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @trunci(%arg0 : i32) -> i1 { // NOTE(review): %0 is only bounded above by -14, so trunci can wrap (e.g. -65536 truncates to i16 0, making `sle %1, -14` false); expecting `true` presumes trunci ignores wraparound - confirm against TruncIOp::inferResultRanges. + %c-14_i32 = arith.constant -14 : i32 + %c-14_i16 = arith.constant -14 : i16 + %ci16_smin = arith.constant 0xffff8000 : i32 + %0 = arith.minsi %arg0, %c-14_i32 : i32 + %1 = arith.trunci %0 : i32 to i16 + %2 = arith.cmpi sle, %1, %c-14_i16 : i16 + %3 = arith.extsi %1 : i16 to i32 + %4 = arith.cmpi sle, %3, %c-14_i32 : i32 + %5 = arith.cmpi sge, %3, %ci16_smin : i32 + %6 = arith.andi %2, %4 : i1 + %7 = arith.andi %6, %5 : i1 + return %7 : i1 +} + +// CHECK-LABEL: func @index_cast +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @index_cast(%arg0 : index) -> i1 { // %1 round-trips through i32 and is sign-extended back, so it is >= INT32_MIN. + %ci32_smin = arith.constant 0xffffffff80000000 : i64 + %0 = arith.index_cast %arg0 : index to i32 + %1 = arith.index_cast %0 : i32 to index + %2 = arith.index_cast %ci32_smin : i64 to index + %3 = arith.cmpi sge, %1, %2 : index + return %3 : i1 +} + +// CHECK-LABEL: func @shli +// CHECK: %[[ret:.*]] = arith.cmpi sgt +// CHECK: return %[[ret]] +func @shli(%arg0 : i32, %arg1 : i1) -> i1 { // %1 is in [-1, 2] and the shift amount is 2 or 4, so %3 lies in [-16, 32]; only sgt 8 stays. + %c2 = arith.constant 2 : 
i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-1 = arith.constant -1 : i32 + %c-16 = arith.constant -16 : i32 + %0 = arith.maxsi %arg0, %c-1 : i32 + %1 = arith.minsi %0, %c2 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shli %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-16 : i32 + %5 = arith.cmpi sle, %3, %c32 : i32 + %6 = arith.cmpi sgt, %3, %c8 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @shrui +// CHECK: %[[ret:.*]] = arith.cmpi uge +// CHECK: return %[[ret]] +func @shrui(%arg0 : i1) -> i1 { // 32 shifted right by 2 or 4 lies in [2, 8], so only the `uge %1, 8` compare stays. + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %0 = arith.select %arg0, %c2, %c4 : i32 + %1 = arith.shrui %c32, %0 : i32 + %2 = arith.cmpi ule, %1, %c8 : i32 + %3 = arith.cmpi uge, %1, %c2 : i32 + %4 = arith.cmpi uge, %1, %c8 : i32 + %5 = arith.andi %2, %3 : i1 + %6 = arith.andi %5, %4 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @shrsi +// CHECK: %[[ret:.*]] = arith.cmpi slt +// CHECK: return %[[ret]] +func @shrsi(%arg0 : i32, %arg1 : i1) -> i1 { // %1 is in [-32, 32]; shifting right by 2 or 4 gives [-8, 8], so only slt 2 stays. + %c2 = arith.constant 2 : i32 + %c4 = arith.constant 4 : i32 + %c8 = arith.constant 8 : i32 + %c32 = arith.constant 32 : i32 + %c-8 = arith.constant -8 : i32 + %c-32 = arith.constant -32 : i32 + %0 = arith.maxsi %arg0, %c-32 : i32 + %1 = arith.minsi %0, %c32 : i32 + %2 = arith.select %arg1, %c2, %c4 : i32 + %3 = arith.shrsi %1, %2 : i32 + %4 = arith.cmpi sge, %3, %c-8 : i32 + %5 = arith.cmpi sle, %3, %c8 : i32 + %6 = arith.cmpi slt, %3, %c2 : i32 + %7 = arith.andi %4, %5 : i1 + %8 = arith.andi %7, %6 : i1 + return %8 : i1 +} + +// CHECK-LABEL: func @no_aggressive_eq +// CHECK: %[[ret:.*]] = arith.cmpi eq +// CHECK: return %[[ret]] +func @no_aggressive_eq(%arg0 : index) -> i1 { // %0 and %1 both lie in [0, 1] but are not provably equal, so eq must stay. + %c1 = arith.constant 1 : index + %0 = arith.andi %arg0, %c1 : index + %1 = arith.minui %arg0, %c1 : index + %2 = arith.cmpi eq, %0, %1 : 
index + return %2 : i1 +} + +// CHECK-LABEL: func @select_union +// CHECK: %[[ret:.*]] = arith.cmpi ne +// CHECK: return %[[ret]] + +func @select_union(%arg0 : index, %arg1 : i1) -> i1 { // %2 is in [0, 63] or [128, 191]; the union [0, 191] proves slt 192 but still contains 100, so ne stays. + %c64 = arith.constant 64 : index + %c100 = arith.constant 100 : index + %c128 = arith.constant 128 : index + %c192 = arith.constant 192 : index + %0 = arith.remui %arg0, %c64 : index + %1 = arith.addi %0, %c128 : index + %2 = arith.select %arg1, %0, %1 : index + %3 = arith.cmpi slt, %2, %c192 : index + %4 = arith.cmpi ne, %c100, %2 : index + %5 = arith.andi %3, %4 : i1 + return %5 : i1 +} + +// CHECK-LABEL: func @if_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @if_union(%arg0 : index, %arg1 : i1) -> i1 { // %0 is in [0, 4]: one region yields [0, 16], the other [-4, 0], so the union lies in [-4, 16]. + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + %1 = scf.if %arg1 -> index { + %10 = arith.muli %0, %0 : index + scf.yield %10 : index + } else { + %20 = arith.muli %0, %c-1 : index + scf.yield %20 : index + } + %2 = arith.cmpi sle, %1, %c16 : index + %3 = arith.cmpi sge, %1, %c-4 : index + %4 = arith.andi %2, %3 : i1 + return %4 : i1 +} + +// CHECK-LABEL: func @branch_union +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func @branch_union(%arg0 : index, %arg1 : i1) -> i1 { // Same as if_union, but the two values meet as the block argument of ^bb3. + %c4 = arith.constant 4 : index + %c16 = arith.constant 16 : index + %c-1 = arith.constant -1 : index + %c-4 = arith.constant -4 : index + %0 = arith.minui %arg0, %c4 : index + cf.cond_br %arg1, ^bb1, ^bb2 +^bb1 : + %1 = arith.muli %0, %0 : index + cf.br ^bb3(%1 : index) +^bb2 : + %2 = arith.muli %0, %c-1 : index + cf.br ^bb3(%2 : index) +^bb3(%3 : index) : + %4 = arith.cmpi sle, %3, %c16 : index + %5 = arith.cmpi sge, %3, %c-4 : index + %6 = arith.andi %4, %5 : i1 + return %6 : i1 +} + +// CHECK-LABEL: func @loop_bound_not_inferred_with_branch +// CHECK-DAG: %[[min:.*]] = arith.cmpi sge +// CHECK-DAG: %[[max:.*]] 
= arith.cmpi slt +// CHECK-DAG: %[[ret:.*]] = arith.andi %[[min]], %[[max]] +// CHECK: return %[[ret]] +func @loop_bound_not_inferred_with_branch(%arg0 : index, %arg1 : i1) -> i1 { // %5 is the counter at loop exit, but the analysis is not expected to bound it through the ^bb1 back-edge, so neither compare should fold. + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %0 = arith.minui %arg0, %c4 : index + cf.br ^bb2(%c0 : index) +^bb1(%1 : index) : + %2 = arith.addi %1, %c1 : index + cf.br ^bb2(%2 : index) +^bb2(%3 : index): + %4 = arith.cmpi ult, %3, %c4 : index + cf.cond_br %4, ^bb1(%3 : index), ^bb3(%3 : index) +^bb3(%5 : index) : + %6 = arith.cmpi sge, %5, %c0 : index + %7 = arith.cmpi slt, %5, %c4 : index + %8 = arith.andi %6, %7 : i1 + return %8 : i1 +} + 