diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h --- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h @@ -28,6 +28,12 @@ #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc" +//===----------------------------------------------------------------------===// +// Arithmetic Dialect Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h" + //===----------------------------------------------------------------------===// // Arithmetic Dialect Operations //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -10,6 +10,7 @@ #define ARITHMETIC_OPS include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" +include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -49,7 +50,8 @@ // Base class for integer binary operations. class Arith_IntBinaryOp traits = []> : - Arith_BinaryOp, + Arith_BinaryOp]>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>, Results<(outs SignlessIntegerLike:$result)>; @@ -87,7 +89,8 @@ // Cast from an integer type to another integer type. class Arith_IToICastOp traits = []> : Arith_CastOp; + SignlessFixedWidthIntegerLike, + traits # [DeclareOpInterfaceMethods]>; // Cast from an integer type to a floating point type. class Arith_IToFCastOp traits = []> : Arith_CastOp; @@ -104,7 +107,8 @@ class Arith_CompareOp traits = []> : Arith_Op]> { + "lhs", "result", "::getI1SameShape($_self)">, + DeclareOpInterfaceMethods]> { let results = (outs BoolLike:$result); let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; @@ -124,6 +128,7 @@ def Arith_ConstantOp : Op, + DeclareOpInterfaceMethods, TypesMatchWith< "result and attribute have the same type", "value", "result", "$_self">]> { @@ -973,7 +978,8 @@ "signless-integer-like or memref of signless-integer">; def Arith_IndexCastOp : Arith_CastOp<"index_cast", IndexCastTypeConstraint, - IndexCastTypeConstraint> { + IndexCastTypeConstraint, + [DeclareOpInterfaceMethods]> { let summary = "cast between index and integer types"; let description = [{ Casts between scalar or vector integers and corresponding 'index' scalar or @@ -1166,7 +1172,8 @@ //===----------------------------------------------------------------------===// def SelectOp : Arith_Op<"select", [ - AllTypesMatch<["true_value", "false_value", "result"]> + AllTypesMatch<["true_value", "false_value", "result"]>, + DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ @@ -1206,7 +1213,7 @@ let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; - + // FIXME: Switch this to use the declarative assembly format. let hasCustomAssemblyFormat = 1; } diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt @@ -3,3 +3,4 @@ mlir_tablegen(ArithmeticOpsEnums.cpp.inc -gen-enum-defs) add_mlir_dialect(ArithmeticOps arith) add_mlir_doc(ArithmeticOps ArithmeticOps Dialects/ -gen-dialect-doc) +add_mlir_interface(InferIntRangeInterface) diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h @@ -0,0 +1,76 @@ +//===- InferIntRangeInterface.h - Infer Ranges Interfaces --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of the integer range inference interface +// defined in `InferIntRange.td` +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H +#define MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace arith { +/// A pair of attributes representing the minimum and maximum value of an +/// integer value. These bounds are inclusive in both ends, so the pair (4, 5) +/// represents the bounds 4 <= x <= 5, while (4, 4) is the constant value 4. +/// If either bound is unset, this is treated as negative or positive infinity, +/// respectively. +struct IntRangeAttrs : public std::pair { + using std::pair::pair; + + static IntRangeAttrs join(const IntRangeAttrs &a, const IntRangeAttrs &b) { + IntegerAttr aMin = a.first; + IntegerAttr bMin = b.first; + IntegerAttr aMax = a.second; + IntegerAttr bMax = b.second; + IntegerAttr min, max; + if (aMin && bMin) + min = aMin.getValue().slt(bMin.getValue()) ? aMin : bMin; + if (aMax && bMax) + max = aMax.getValue().sgt(bMax.getValue()) ? aMax : bMax; + return {min, max}; + } + + static IntRangeAttrs getPessimisticValueState(MLIRContext *context) { + return {{}, {}}; + } + + static IntRangeAttrs getPessimisticValueState(Value v) { + return getPessimisticValueState(v.getContext()); + } + + inline friend raw_ostream &operator<<(raw_ostream &os, + const IntRangeAttrs &range) { + os << "["; + if (range.first) + os << range.first.getValue(); + else + os << "-inf"; + os << ", "; + if (range.second) + os << range.second.getValue(); + else + os << "inf"; + return os << "]"; + } +}; +} // namespace arith +} // namespace mlir + +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h.inc" + +#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE_H diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.td @@ -0,0 +1,49 @@ +//===- InferIntRangeInterface.td - Integer Range Analysis interface -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for range analysis on scalar integers +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE +#define MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE + +include "mlir/IR/OpBase.td" + +def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> { + let description = [{ + Allows operations to participate in range analysis for scalar integer + values by providing a methods that allows them to specify lower and + upper bounds on their result(s) given lower and upper bounds on their + input(s) if known. + }]; + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod<[{ + Infer the bounds on the results of this op given the bounds on its + arguments. For each result value, the function should place a pair + of `IntegerAttr`s, also known as an `IntRangeAttrs`, into + `resultRanges`. If a result does not have integer type, or if + bounds can not be inferred, the pair corresponding to it should be (null, null) + The returned attributes, if present, should have the same type + as the corresponding result. + + This function will only be called when at least one result of the op + is a scalar value + + `argRanges` contains one `IntRangeAttrs` for each argument to + the op in ODS order. Non-integer arguments have `(nullptr, nullptr)` + in their position. + }], + "void", "inferResultRanges", (ins + "::llvm::ArrayRef<::mlir::arith::IntRangeAttrs>":$argRanges, + "::llvm::SmallVectorImpl<::mlir::arith::IntRangeAttrs> &":$resultRanges) + >]; +} +#endif // MLIR_DIALECT_ARITHMETIC_IR_INFERINTRANGEINTERFACE diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp @@ -9,15 +9,23 @@ #include #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" -#include "llvm/ADT/SmallString.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "arith" using namespace mlir; using namespace mlir::arith; @@ -40,6 +48,154 @@ rhs.cast().getInt()); } +// Get the bitwidth of the attributes holding constants of type `type` +static unsigned int getAttrBitwidth(Type type) { + if (type.isIndex()) + return IndexType::kInternalStorageBitWidth; + return type.getIntOrFloatBitWidth(); +} + +// Function that evaluates the result of doing something on arithmetic constants +// and returns None on overflow +using ConstArithFn = + llvm::function_ref(const APInt &, const APInt &)>; + +// If both `left` and `right` are defined, return the result of +// `op(left.getValue(), right.getValue()`, where None is converted +// to a null IntegerAttr. Otherwise, return the null attribute +static IntegerAttr compute(ConstArithFn op, IntegerAttr left, + IntegerAttr right) { + assert(left.getType() == right.getType() && + "Arithmetic ops don't have mismatched operands"); + if (!left || !right) + return {}; + llvm::Optional result = op(left.getValue(), right.getValue()); + if (!result.hasValue()) + return {}; + return IntegerAttr::get(left.getType(), *result); +} + +// Compute op(minLeft, minRight) and op(maxLeft, maxRight) if possible, +// If either computation overflows, make the result unbounded +static IntRangeAttrs computeBoundsBy(ConstArithFn op, IntegerAttr minLeft, + IntegerAttr minRight, IntegerAttr maxLeft, + IntegerAttr maxRight) { + IntegerAttr min, max; + if (minLeft && minRight) { + min = compute(op, minLeft, minRight); + if (!min) + return {{}, {}}; + } + if (maxLeft && maxRight) { + max = compute(op, maxLeft, maxRight); + if (!max) + return {{}, {}}; + } + return {min, max}; +} + +// Compute the minimum and maximum of `(op(l, r) for l in lhs for r in rhs)`, +// ignoring null attributes. Returns (null, null) if `op` overflows +static IntRangeAttrs minMaxBy(ConstArithFn op, ArrayRef lhs, + ArrayRef rhs, bool signedCmp) { + IntegerAttr min, max; + for (IntegerAttr left : lhs) { + for (IntegerAttr right : rhs) { + if (!left || !right) { + // A missing lower or upper bound should be accounted for by the parent + // function + continue; + } + IntegerAttr thisResult = compute(op, left, right); + if (!thisResult) { + return {{}, {}}; + } + APInt thisValue = thisResult.getValue(); + if (min) + min = (signedCmp ? thisValue.slt(min.getValue()) + : thisValue.ult(min.getValue())) + ? thisResult + : min; + else + min = thisResult; + + if (max) + max = (signedCmp ? thisValue.sgt(max.getValue()) + : thisValue.ugt(max.getValue())) + ? thisResult + : max; + else + max = thisResult; + } + } + return {min, max}; +} + +/// Interperet the bounds [min, max] as bounds on an unsigned value +IntRangeAttrs asUnsigned(Type resultType, const IntRangeAttrs &bounds) { + unsigned int bitwidth = getAttrBitwidth(resultType); + IntegerAttr zero = IntegerAttr::get(resultType, APInt::getZero(bitwidth)); + IntegerAttr min, max; + std::tie(min, max) = bounds; + if (!min) { + // Special case: we know the value is negative, which means it's between + // signed_min and the lwer bound + if (max) { + APInt maxValue = max.getValue(); + if (maxValue.isNegative()) { + IntegerAttr signedMin = + IntegerAttr::get(resultType, APInt::getSignedMinValue(bitwidth)); + return {signedMin, max}; + } + } + // Otherwise, all we know is that the value is at least zero + return {zero, {}}; + } + APInt minValue = min.getValue(); + bool negativeMin = minValue.isNegative(); + if (negativeMin) { + // Special case: both bounds less than zero lets us draw actual conclusions + if (max) { + APInt maxValue = max.getValue(); + bool negativeMax = maxValue.isNegative(); + if (negativeMin && negativeMax) { + APInt umin = minValue.ule(maxValue) ? minValue : maxValue; + APInt umax = minValue.ugt(maxValue) ? minValue : maxValue; + return {IntegerAttr::get(resultType, umin), + IntegerAttr::get(resultType, umax)}; + } + } + // Otherwise, either there is no upper bound or zero is in the range, + // and so we can draw no conclusions + return {zero, {}}; + } + return {min, max}; +} + +/// Interperet the bounds [min, max] as bounds on a signed value +IntRangeAttrs asSigned(const IntRangeAttrs &bounds) { + IntegerAttr min, max; + std::tie(min, max) = bounds; + if (max) { + APInt maxValue = max.getValue(); + if (maxValue.isNegative()) { + if (min) { + APInt minValue = min.getValue(); + // Special case: negative bounds that ended up read as unsigned + if (minValue.isNegative()) { + APInt smin = minValue.sle(maxValue) ? minValue : maxValue; + APInt smax = minValue.sgt(maxValue) ? minValue : maxValue; + Type type = min.getType(); + return {IntegerAttr::get(type, smin), IntegerAttr::get(type, smax)}; + } + // Special case: bounds from unsigned ops / overflow > SIGNED_MAX + return {{}, {}}; + } + } + } + return {min, max}; +} + /// Invert an integer comparison predicate. arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { switch (pred) { @@ -141,6 +297,14 @@ return getValue(); } +void arith::ConstantOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + // Return null for non-scalar integer constants + auto value = getValue().dyn_cast_or_null(); + resultRanges.push_back({value, value}); +} + void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { auto type = builder.getIntegerType(width); @@ -215,6 +379,25 @@ context); } +void arith::AddIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + // FIXME: is there a less manual way to do this + IntegerAttr lhsMin = argRanges[0].first; + IntegerAttr rhsMin = argRanges[1].first; + IntegerAttr lhsMax = argRanges[0].second; + IntegerAttr rhsMax = argRanges[1].second; + + ConstArithFn add = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.sadd_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + resultRanges.push_back(computeBoundsBy(add, lhsMin, rhsMin, lhsMax, rhsMax)); +} + //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// @@ -239,6 +422,24 @@ context); } +void arith::SubIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr lhsMin = argRanges[0].first; + IntegerAttr rhsMin = argRanges[1].first; + IntegerAttr lhsMax = argRanges[0].second; + IntegerAttr rhsMax = argRanges[1].second; + + ConstArithFn sub = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.ssub_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + resultRanges.push_back(computeBoundsBy(sub, lhsMin, rhsMax, lhsMax, rhsMin)); +} + //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// @@ -257,6 +458,41 @@ operands, [](const APInt &a, const APInt &b) { return a * b; }); } +void arith::MulIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr lhsMin = argRanges[0].first; + IntegerAttr rhsMin = argRanges[1].first; + IntegerAttr lhsMax = argRanges[0].second; + IntegerAttr rhsMax = argRanges[1].second; + + IntegerAttr min, max; + + bool noNegatives = (lhsMin && rhsMin && lhsMin.getValue().isNonNegative() && + rhsMin.getValue().isNonNegative()); + bool canBoundBelow = + (lhsMin && rhsMin && (noNegatives || (lhsMax && rhsMax))); + bool canBoundAbove = + (lhsMax && rhsMax && (noNegatives || (lhsMin && rhsMin))); + auto mul = [&noNegatives](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = + noNegatives ? a.umul_ov(b, overflowed) : a.smul_ov(b, overflowed); + if (overflowed) { + LLVM_DEBUG(llvm::dbgs() + << "Multiplying " << result << " := " << a << " * " << b + << " overflowed - noNegatives = " << noNegatives << "\n"); + return {}; + } + return result; + }; + std::tie(min, max) = minMaxBy(mul, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*signedCmp=*/!noNegatives); + + resultRanges.push_back( + {canBoundBelow ? min : nullptr, canBoundAbove ? max : nullptr}); +} + //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// @@ -285,6 +521,32 @@ return div0 ? Attribute() : result; } +static IntRangeAttrs inferDivUIRange(Type resultType, const IntRangeAttrs &lhs, + const IntRangeAttrs &rhs) { + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + unsigned int bitwidth = getAttrBitwidth(resultType); + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, lhs); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, rhs); + if (rhsMin && !rhsMin.getValue().isZero()) { + if (!rhsMax) // Bound divisor above by 0xffff...fff to get lower bound of 0 + rhsMax = IntegerAttr::get(resultType, APInt::getAllOnesValue(bitwidth)); + ConstArithFn udiv = [](auto &a, auto &b) -> llvm::Optional { + return a.udiv(b); + }; + return computeBoundsBy(udiv, lhsMin, rhsMax, lhsMax, rhsMin); + } + // Otherwise, it's possible we might divide by 0 + return {{}, {}}; +} + +void arith::DivUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + resultRanges.push_back( + inferDivUIRange(resultType, argRanges[0], argRanges[1])); +} + //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// @@ -313,8 +575,46 @@ return overflowOrDiv0 ? Attribute() : result; } +static IntRangeAttrs inferDivSIRange(Type resultType, const IntRangeAttrs &lhs, + const IntRangeAttrs &rhs) { + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(lhs); + std::tie(rhsMin, rhsMax) = asSigned(rhs); + bool canDivide = false; + if (rhsMin && rhsMin.getValue().isStrictlyPositive()) + canDivide = true; + if (rhsMax && rhsMax.getValue().isNegative()) + canDivide = true; + + if (canDivide) { + unsigned int bitwidth = getAttrBitwidth(resultType); + // Unbounded below + negative upper bound -> lower bound = INT_MIN + if (!rhsMin) + rhsMin = IntegerAttr::get(resultType, APInt::getSignedMinValue(bitwidth)); + // Unbounded above + positive lower bound -> upper bound = INT_MAX + if (!rhsMax) + rhsMax = IntegerAttr::get(resultType, APInt::getSignedMaxValue(bitwidth)); + ConstArithFn sdiv = [](auto &a, auto &b) -> llvm::Optional { + bool overflowed = false; + APInt result = a.sdiv_ov(b, overflowed); + if (overflowed) + return {}; + return result; + }; + return computeBoundsBy(sdiv, lhsMin, rhsMax, lhsMax, rhsMin); + } + return {{}, {}}; +} + +void arith::DivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + resultRanges.push_back( + inferDivSIRange(getResult().getType(), argRanges[0], argRanges[1])); +} + //===----------------------------------------------------------------------===// -// Ceil and floor division folding helpers +// Ceil and floor division helpers //===----------------------------------------------------------------------===// static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, @@ -325,6 +625,44 @@ return val.sadd_ov(one, overflow); } +static IntegerAttr correctCeilDivUIBound(IntegerAttr lhs, IntegerAttr rhs, + IntegerAttr result) { + if (result && lhs && rhs) { + if (!lhs.getValue().urem(rhs.getValue()).isZero()) { + return IntegerAttr::get(result.getType(), result.getValue() + 1); + } + } + return result; +} + +static IntegerAttr correctCeilDivSIBound(IntegerAttr lhs, IntegerAttr rhs, + IntegerAttr result) { + if (result && lhs && rhs) { + APInt lhsValue = lhs.getValue(); + APInt rhsValue = rhs.getValue(); + APInt resultValue = result.getValue(); + if (!lhsValue.srem(rhsValue).isZero() && + lhsValue.isNegative() == rhsValue.isNegative()) { + return IntegerAttr::get(result.getType(), resultValue + 1); + } + } + return result; +} + +static IntegerAttr correctFloorDivBound(IntegerAttr lhs, IntegerAttr rhs, + IntegerAttr result) { + if (result && lhs && rhs) { + APInt lhsValue = lhs.getValue(); + APInt rhsValue = rhs.getValue(); + APInt resultValue = result.getValue(); + if (!lhsValue.srem(rhsValue).isZero() && + lhsValue.isNegative() != rhsValue.isNegative()) { + return IntegerAttr::get(result.getType(), resultValue - 1); + } + } + return result; +} + //===----------------------------------------------------------------------===// // CeilDivUIOp //===----------------------------------------------------------------------===// @@ -356,6 +694,24 @@ return overflowOrDiv0 ? Attribute() : result; } +void arith::CeilDivUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntRangeAttrs lhs = argRanges[0]; + IntRangeAttrs rhs = argRanges[1]; + + Type resultType = getResult().getType(); + IntegerAttr min, max; + std::tie(min, max) = inferDivUIRange(resultType, lhs, rhs); + + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, lhs); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, rhs); + min = correctCeilDivUIBound(lhsMin, rhsMax, min); + max = correctCeilDivUIBound(lhsMax, rhsMin, max); + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// @@ -411,6 +767,24 @@ return overflowOrDiv0 ? Attribute() : result; } +void arith::CeilDivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntRangeAttrs lhs = argRanges[0]; + IntRangeAttrs rhs = argRanges[1]; + + Type resultType = getResult().getType(); + IntegerAttr min, max; + std::tie(min, max) = inferDivSIRange(resultType, lhs, rhs); + + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(lhs); + std::tie(rhsMin, rhsMax) = asSigned(rhs); + min = correctCeilDivSIBound(lhsMin, rhsMax, min); + max = correctCeilDivSIBound(lhsMax, rhsMin, max); + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// @@ -466,6 +840,24 @@ return overflowOrDiv0 ? Attribute() : result; } +void arith::FloorDivSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntRangeAttrs lhs = argRanges[0]; + IntRangeAttrs rhs = argRanges[1]; + + Type resultType = getResult().getType(); + IntegerAttr min, max; + std::tie(min, max) = inferDivSIRange(resultType, lhs, rhs); + + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(lhs); + std::tie(rhsMin, rhsMax) = asSigned(rhs); + min = correctFloorDivBound(lhsMin, rhsMax, min); + max = correctFloorDivBound(lhsMax, rhsMin, max); + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// @@ -490,6 +882,37 @@ return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } +void arith::RemUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr rhsMin, rhsMax, min, max; + Type resultType = getResult().getType(); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + + if (rhsMin && rhsMax && !rhsMin.getValue().isZero()) { + APInt maxDivisor = rhsMax.getValue(); + min = + IntegerAttr::get(resultType, APInt::getZero(maxDivisor.getBitWidth())); + max = IntegerAttr::get(resultType, maxDivisor - 1); + // Special case: sweeping out a contiguous range in N/[modulus] + IntegerAttr lhsMin, lhsMax; + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, argRanges[0]); + if (lhsMin && lhsMax && rhsMin == rhsMax) { + APInt minDividend = lhsMin.getValue(); + APInt maxDividend = lhsMax.getValue(); + if ((maxDividend - minDividend).ult(maxDivisor)) { + APInt minRem = minDividend.urem(maxDivisor); + APInt maxRem = maxDividend.urem(maxDivisor); + if (minRem.ule(maxRem)) { + min = IntegerAttr::get(resultType, minRem); + max = IntegerAttr::get(resultType, maxRem); + } + } + } + } + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// @@ -514,6 +937,74 @@ return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } +void arith::RemSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax, min, max; + std::tie(lhsMin, lhsMax) = asSigned(argRanges[0]); + std::tie(rhsMin, rhsMax) = asSigned(argRanges[1]); + + Type resultType = getResult().getType(); + // No bounds if zero could be a divisor + bool canBound = rhsMax && rhsMin && + (rhsMin.getValue().isStrictlyPositive() || + rhsMax.getValue().isNegative()); + if (canBound) { + APInt maxDivisor = rhsMin.getValue().isStrictlyPositive() + ? rhsMax.getValue() + : rhsMin.getValue().abs(); + bool canNegativeDividend = !(lhsMin && lhsMin.getValue().isNonNegative()); + bool canPositiveDividend = !(lhsMax && lhsMax.getValue().isNonPositive()); + APInt zero = APInt::getZero(maxDivisor.getBitWidth()); + APInt maxPositiveResult = maxDivisor - 1; + APInt minNegativeResult = -maxPositiveResult; + min = IntegerAttr::get(resultType, + canNegativeDividend ? minNegativeResult : zero); + max = IntegerAttr::get(resultType, + canPositiveDividend ? maxPositiveResult : zero); + // Special case: sweeping out a contiguous range in N/[modulus] + if (lhsMin && lhsMax && rhsMin == rhsMax) { + APInt minDividend = lhsMin.getValue(); + APInt maxDividend = lhsMax.getValue(); + if ((maxDividend - minDividend).ult(maxDivisor)) { + APInt minRem = minDividend.srem(maxDivisor); + APInt maxRem = maxDividend.srem(maxDivisor); + if (minRem.sle(maxRem)) { + min = IntegerAttr::get(resultType, minRem); + max = IntegerAttr::get(resultType, maxRem); + } + } + } + } + resultRanges.push_back({min, max}); +} + +//===----------------------------------------------------------------------===// +// Helpers for implementing range inference on bitwise ops +//===----------------------------------------------------------------------===// + +// "Widen" bounds - if 0bvvvvv??? <= a <= 0bvvvvv???, +// relax the bounds to 0bvvvvv000 <= a <= 0bvvvvv111, where vvvvv are the bits +// that both bonuds have in common. This gives us a consertive approximation for +// what values can be passed to bitwise operations. +// This will widen missing bounds to all zeroes / all ones so we can handle +// [unbounded] & 0xff => [0, 255] +static IntRangeAttrs widenBitwiseBounds(Type resultType, + const IntRangeAttrs &bound) { + unsigned int bitwidth = getAttrBitwidth(resultType); + APInt leftVal = + bound.first ? bound.first.getValue() : APInt::getZero(bitwidth); + APInt rightVal = + bound.second ? bound.second.getValue() : APInt::getAllOnesValue(bitwidth); + unsigned int differingBits = + bitwidth - (leftVal ^ rightVal).countLeadingZeros(); + leftVal.clearLowBits(differingBits); + rightVal.setLowBits(differingBits); + IntegerAttr zeroes = IntegerAttr::get(resultType, leftVal); + IntegerAttr ones = IntegerAttr::get(resultType, rightVal); + return {zeroes, ones}; +} + //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// @@ -531,6 +1022,19 @@ operands, [](APInt a, const APInt &b) { return std::move(a) & b; }); } +void arith::AndIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[0])); + std::tie(rhsZeros, rhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[1])); + resultRanges.push_back( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a & b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false)); +} //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// @@ -548,6 +1052,20 @@ operands, [](APInt a, const APInt &b) { return std::move(a) | b; }); } +void arith::OrIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[0])); + std::tie(rhsZeros, rhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[1])); + resultRanges.push_back( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a | b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false)); +} + //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// @@ -573,6 +1091,20 @@ patterns.add(context); } +void arith::XOrIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsZeros, lhsOnes, rhsZeros, rhsOnes; + std::tie(lhsZeros, lhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[0])); + std::tie(rhsZeros, rhsOnes) = + widenBitwiseBounds(resultType, asUnsigned(resultType, argRanges[1])); + resultRanges.push_back( + minMaxBy([](auto &a, auto &b) -> llvm::Optional { return a ^ b; }, + {lhsZeros, lhsOnes}, {rhsZeros, rhsOnes}, /*signedCmp=*/false)); +} + //===----------------------------------------------------------------------===// // NegFOp //===----------------------------------------------------------------------===// @@ -656,6 +1188,29 @@ }); } +void arith::MaxSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(argRanges[0]); + std::tie(rhsMin, rhsMax) = asSigned(argRanges[1]); + + IntegerAttr min = {}; + // Take the largest lower bound (if any) + if (lhsMin && rhsMin) + min = lhsMin.getValue().sgt(rhsMin.getValue()) ? lhsMin : rhsMin; + else if (lhsMin) + min = lhsMin; + else if (rhsMin) + min = rhsMin; + + // If both upper bounds are present, take their max, be unbounded otherwise + IntegerAttr max = {}; + if (lhsMax && rhsMax) + max = lhsMax.getValue().sgt(rhsMax.getValue()) ? lhsMax : rhsMax; + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// @@ -682,6 +1237,30 @@ }); } +void arith::MaxUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, argRanges[0]); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + + IntegerAttr min = {}; + // Take the largest lower bound (if any) + if (lhsMin && rhsMin) + min = lhsMin.getValue().ugt(rhsMin.getValue()) ? lhsMin : rhsMin; + else if (lhsMin) + min = lhsMin; + else if (rhsMin) + min = rhsMin; + + // If both upper bounds are present, take their max, be unbounded otherwise + IntegerAttr max = {}; + if (lhsMax && rhsMax) + max = lhsMax.getValue().ugt(rhsMax.getValue()) ? lhsMax : rhsMax; + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // MinFOp //===----------------------------------------------------------------------===// @@ -730,6 +1309,29 @@ }); } +void arith::MinSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(argRanges[0]); + std::tie(rhsMin, rhsMax) = asSigned(argRanges[1]); + + IntegerAttr min = {}; + // If both lower bounds are present, take their minimum + if (lhsMin && rhsMin) + min = lhsMin.getValue().slt(rhsMin.getValue()) ? lhsMin : rhsMin; + + // For upper bounds, take the smallest (with absent -> +infinity) + IntegerAttr max = {}; + if (lhsMax && rhsMax) + max = lhsMax.getValue().slt(rhsMax.getValue()) ? lhsMax : rhsMax; + else if (lhsMax) + max = lhsMax; + else if (rhsMax) + max = rhsMax; + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// @@ -756,6 +1358,30 @@ }); } +void arith::MinUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, argRanges[0]); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + + IntegerAttr min = {}; + // If both lower bounds are present, take their minimum + if (lhsMin && rhsMin) + min = lhsMin.getValue().ult(rhsMin.getValue()) ? lhsMin : rhsMin; + + // For upper bounds, take the smallest (with absent -> +infinity) + IntegerAttr max = {}; + if (lhsMax && rhsMax) + max = lhsMax.getValue().ult(rhsMax.getValue()) ? lhsMax : rhsMax; + else if (lhsMax) + max = lhsMax; + else if (rhsMax) + max = rhsMax; + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -899,6 +1525,29 @@ return verifyExtOp(*this); } +void arith::ExtUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + unsigned int srcWidth = sourceType.getIntOrFloatBitWidth(); + unsigned int destWidth = destType.getIntOrFloatBitWidth(); + IntegerAttr min, max; + std::tie(min, max) = asUnsigned(sourceType, argRanges[0]); + + if (min) + min = IntegerAttr::get(destType, min.getValue().zext(destWidth)); + else + min = IntegerAttr::get(destType, APInt::getZero(destWidth)); + + if (max) + max = IntegerAttr::get(destType, max.getValue().zext(destWidth)); + else + max = IntegerAttr::get(destType, APInt::getLowBitsSet(destWidth, srcWidth)); + + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// @@ -933,6 +1582,35 @@ return verifyExtOp(*this); } +static IntRangeAttrs extSIRange(const IntRangeAttrs &range, Type sourceType, + Type destType) { + unsigned int srcWidth = getAttrBitwidth(sourceType); + unsigned int destWidth = getAttrBitwidth(destType); + IntegerAttr min, max; + std::tie(min, max) = asSigned(range); + if (min) + min = IntegerAttr::get(destType, min.getValue().sext(destWidth)); + else + min = IntegerAttr::get( + destType, APInt::getHighBitsSet(destWidth, destWidth - srcWidth + 1)); + + if (max) + max = IntegerAttr::get(destType, max.getValue().sext(destWidth)); + else + max = IntegerAttr::get(destType, + APInt::getLowBitsSet(destWidth, srcWidth - 1)); + + return {min, max}; +} + +void arith::ExtSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + resultRanges.push_back(extSIRange(argRanges[0], sourceType, destType)); +} + //===----------------------------------------------------------------------===// // ExtFOp //===----------------------------------------------------------------------===// @@ -983,6 +1661,26 @@ return verifyTruncateOp(*this); } +static IntRangeAttrs truncIRange(const IntRangeAttrs &range, Type destType) { + IntegerAttr min, max; + std::tie(min, max) = range; + unsigned int destWidth = getAttrBitwidth(destType); + + if (min) + min = IntegerAttr::get(destType, min.getValue().trunc(destWidth)); + if (max) + max = IntegerAttr::get(destType, max.getValue().trunc(destWidth)); + + return {min, max}; +} + +void arith::TruncIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type destType = getResult().getType(); + resultRanges.push_back(truncIRange(argRanges[0], destType)); +} + //===----------------------------------------------------------------------===// // TruncFOp //===----------------------------------------------------------------------===// @@ -1185,6 +1883,22 @@ patterns.add(context); } +void arith::IndexCastOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type sourceType = getOperand().getType(); + Type destType = getResult().getType(); + unsigned int srcWidth = getAttrBitwidth(sourceType); + unsigned int destWidth = getAttrBitwidth(destType); + + if (srcWidth < destWidth) + resultRanges.push_back(extSIRange(argRanges[0], sourceType, destType)); + else if (srcWidth > destWidth) + resultRanges.push_back(truncIRange(argRanges[0], destType)); + else + resultRanges.push_back(argRanges[0]); +} + //===----------------------------------------------------------------------===// // BitcastOp //===----------------------------------------------------------------------===// @@ -1353,6 +2067,74 @@ patterns.insert(context); } +bool isStaticallyTrue(arith::CmpIPredicate pred, IntegerAttr lhsMin, + IntegerAttr lhsMax, IntegerAttr rhsMin, + IntegerAttr rhsMax) { + switch (pred) { + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::slt: + case arith::CmpIPredicate::ult: + return (lhsMax && rhsMin && + applyCmpPredicate(pred, lhsMax.getValue(), rhsMin.getValue())); + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::sgt: + case arith::CmpIPredicate::ugt: + return (lhsMin && rhsMax && + applyCmpPredicate(pred, lhsMin.getValue(), rhsMax.getValue())); + case arith::CmpIPredicate::eq: + // Require equality always - that is, two equal constants + return (lhsMin && lhsMax && rhsMin && rhsMax && lhsMin == lhsMax && + rhsMin == rhsMax && lhsMin == rhsMin); + case arith::CmpIPredicate::ne: + return isStaticallyTrue(arith::CmpIPredicate::slt, lhsMin, lhsMax, rhsMin, + rhsMax) || + isStaticallyTrue(arith::CmpIPredicate::sgt, lhsMin, lhsMax, rhsMin, + rhsMax); + } + return false; +} + +void arith::CmpIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + arith::CmpIPredicate pred = getPredicate(); + Type operandType = getLhs().getType(); + + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + // Force signed/unsigned interpretations if relevant + switch (pred) { + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: + std::tie(lhsMin, lhsMax) = asSigned(argRanges[0]); + std::tie(rhsMin, rhsMax) = asSigned(argRanges[1]); + break; + case arith::CmpIPredicate::ule: + case arith::CmpIPredicate::ult: + case arith::CmpIPredicate::uge: + case arith::CmpIPredicate::ugt: + std::tie(lhsMin, lhsMax) = asUnsigned(operandType, argRanges[0]); + std::tie(rhsMin, rhsMax) = asUnsigned(operandType, argRanges[1]); + break; + case arith::CmpIPredicate::eq: + case arith::CmpIPredicate::ne: + std::tie(lhsMin, lhsMax) = argRanges[0]; + std::tie(rhsMin, rhsMax) = argRanges[1]; + } + + IntegerAttr min, max; + if (isStaticallyTrue(pred, lhsMin, lhsMax, rhsMin, rhsMax)) + min = max = IntegerAttr::get(getResult().getType(), 1); + else if (isStaticallyTrue(invertPredicate(pred), lhsMin, lhsMax, rhsMin, + rhsMax)) + min = max = IntegerAttr::get(getResult().getType(), 0); + + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// @@ -1715,6 +2497,12 @@ patterns.insert(context); } +void arith::CmpFOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + resultRanges.push_back({{}, {}}); +} + //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// @@ -1829,6 +2617,32 @@ return nullptr; } +void arith::SelectOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + IntegerAttr condMin = argRanges[0].first; + IntegerAttr condMax = argRanges[0].second; + IntegerAttr trueMin = argRanges[1].first; + IntegerAttr trueMax = argRanges[1].second; + IntegerAttr falseMin = argRanges[2].first; + IntegerAttr falseMax = argRanges[2].second; + + if (condMin && condMax && condMin == condMax) { + if (condMin.getValue().isZero()) + resultRanges.push_back({falseMin, falseMax}); + else + resultRanges.push_back({trueMin, trueMax}); + return; + } + IntegerAttr min, max; + if (trueMin && falseMin) + min = trueMin.getValue().slt(falseMin.getValue()) ? trueMin : falseMin; + if (trueMax && falseMax) + max = trueMax.getValue().sgt(falseMax.getValue()) ? trueMax : falseMax; + + resultRanges.push_back({min, max}); +} + ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) { Type conditionType, resultType; SmallVector operands; @@ -1896,6 +2710,36 @@ return bounded ? result : Attribute(); } +void arith::ShLIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + + Type resultType = getResult().getType(); + IntegerAttr rhsMin, rhsMax; + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + if (!rhsMax || !rhsMin) { + resultRanges.push_back({{}, {}}); + return; + } + + ConstArithFn shl = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.shl(r); + }; + IntegerAttr lhsMin, lhsMax, min, max; + std::tie(lhsMin, lhsMax) = argRanges[0]; + bool canBeNegative = !(lhsMin && lhsMin.getValue().isNonNegative()); + std::tie(min, max) = minMaxBy(shl, {lhsMin, lhsMax}, {rhsMin, rhsMax}, + /*signedCmp=*/canBeNegative); + if (!lhsMin) + min = {}; + if (!lhsMax) + max = {}; + + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // ShRUIOp //===----------------------------------------------------------------------===// @@ -1911,6 +2755,34 @@ return bounded ? result : Attribute(); } +void arith::ShRUIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asUnsigned(resultType, argRanges[0]); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + if (!rhsMax || !rhsMin) { + resultRanges.push_back({{}, {}}); + return; + } + + ConstArithFn lshr = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.lshr(r); + }; + IntegerAttr min, max; + std::tie(min, max) = + minMaxBy(lshr, {lhsMin, lhsMax}, {rhsMin, rhsMax}, /*signedCmp=*/false); + if (!lhsMin) + min = {}; + if (!lhsMax) + max = {}; + + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // ShRSIOp //===----------------------------------------------------------------------===// @@ -1926,6 +2798,34 @@ return bounded ? result : Attribute(); } +void arith::ShRSIOp::inferResultRanges( + ArrayRef argRanges, + SmallVectorImpl &resultRanges) { + Type resultType = getResult().getType(); + IntegerAttr lhsMin, lhsMax, rhsMin, rhsMax; + std::tie(lhsMin, lhsMax) = asSigned(argRanges[0]); + std::tie(rhsMin, rhsMax) = asUnsigned(resultType, argRanges[1]); + if (!rhsMax || !rhsMin) { + resultRanges.push_back({{}, {}}); + return; + } + + ConstArithFn ashr = [](auto &l, auto &r) -> Optional { + if (r.uge(r.getBitWidth())) + return {}; + return l.ashr(r); + }; + IntegerAttr min, max; + std::tie(min, max) = + minMaxBy(ashr, {lhsMin, lhsMax}, {rhsMin, rhsMax}, /*signedCmp=*/true); + if (!lhsMin) + min = {}; + if (!lhsMax) + max = {}; + + resultRanges.push_back({min, max}); +} + //===----------------------------------------------------------------------===// // Atomic Enum //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt @@ -5,12 +5,14 @@ add_mlir_dialect_library(MLIRArithmetic ArithmeticOps.cpp ArithmeticDialect.cpp + InferIntRangeInterface.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic DEPENDS MLIRArithmeticOpsIncGen + MLIRInferIntRangeInterfaceIncGen LINK_LIBS PUBLIC MLIRDialect diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp @@ -0,0 +1,10 @@ +//===- InferIntRangeInterface.cpp - Integer range inference interface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h" +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.cpp.inc"