diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -10,9 +10,11 @@
 #include "mlir/Dialect/Index/IR/IndexAttrs.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "llvm/ADT/SmallString.h"
-#include <optional>
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::index;
@@ -313,9 +315,10 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
-    return lhs.slt(rhs) ? lhs : rhs;
-  });
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.slt(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -323,9 +326,10 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
-  return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
-    return lhs.ult(rhs) ? lhs : rhs;
-  });
+  return foldBinaryOpChecked(adaptor.getOperands(),
+                             [](const APInt &lhs, const APInt &rhs) {
+                               return lhs.ult(rhs) ? lhs : rhs;
+                             });
 }
 
 //===----------------------------------------------------------------------===//
@@ -455,19 +459,73 @@
   llvm_unreachable("unhandled IndexCmpPredicate predicate");
 }
 
+/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
+/// values of `cstA` and `cstB`, the max or min operation, and the comparison
+/// predicate. Check whether the value folds in both 32-bit and 64-bit
+/// arithmetic and to the same value.
+///
+/// `lhsOp` must be a `MinSOp`, `MinUOp`, `MaxSOp`, or `MaxUOp`: the
+/// `TypeSwitch` below has no default case. `width` is the bit width used for
+/// the extreme bound of the range. The result of `intrange::evaluatePred` is
+/// returned as-is; the caller only folds when it holds a value.
+static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
+                                             const APInt &cstA,
+                                             const APInt &cstB, unsigned width,
+                                             IndexCmpPredicate pred) {
+  // The result of `min(x, cstA)` is bounded above by `cstA` and the result of
+  // `max(x, cstA)` is bounded below by it; the other bound is the extreme
+  // value representable in `width` bits for the op's signedness.
+  ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
+                                   .Case([&](MinSOp op) {
+                                     return ConstantIntRanges::fromSigned(
+                                         APInt::getSignedMinValue(width), cstA);
+                                   })
+                                   .Case([&](MinUOp op) {
+                                     return ConstantIntRanges::fromUnsigned(
+                                         APInt::getMinValue(width), cstA);
+                                   })
+                                   .Case([&](MaxSOp op) {
+                                     return ConstantIntRanges::fromSigned(
+                                         cstA, APInt::getSignedMaxValue(width));
+                                   })
+                                   .Case([&](MaxUOp op) {
+                                     return ConstantIntRanges::fromUnsigned(
+                                         cstA, APInt::getMaxValue(width));
+                                   });
+  return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
+                                lhsRange, ConstantIntRanges::constant(cstB));
+}
+
 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
+  // Attempt to fold if both inputs are constant.
   auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
   auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
-  if (!lhs || !rhs)
-    return {};
+  if (lhs && rhs) {
+    // Perform the comparison in 64-bit and 32-bit.
+    bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
+    bool result32 = compareIndices(lhs.getValue().trunc(32),
+                                   rhs.getValue().trunc(32), getPred());
+    if (result64 == result32)
+      return BoolAttr::get(getContext(), result64);
+  }
 
-  // Perform the comparison in 64-bit and 32-bit.
-  bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
-  bool result32 = compareIndices(lhs.getValue().trunc(32),
-                                 rhs.getValue().trunc(32), getPred());
-  if (result64 != result32)
-    return {};
-  return BoolAttr::get(getContext(), result64);
+  // Fold `cmp(max/min(x, cstA), cstB)`.
+  Operation *lhsOp = getLhs().getDefiningOp();
+  IntegerAttr cstA, cstB;
+  if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
+      matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) &&
+      matchPattern(getRhs(), m_Constant(&cstB))) {
+    std::optional<bool> result64 = foldCmpOfMaxOrMin(
+        lhsOp, cstA.getValue(), cstB.getValue(), 64, getPred());
+    std::optional<bool> result32 =
+        foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
+                          cstB.getValue().trunc(32), 32, getPred());
+    // Fold if the 32-bit and 64-bit results are the same.
+    if (result64 && result32 && *result64 == *result32)
+      return BoolAttr::get(getContext(), *result64);
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -510,3 +510,14 @@
   // CHECK: return %[[TRUE]]
   return %0 : i1
 }
+
+// CHECK-LABEL: @cmp_maxs
+func.func @cmp_maxs(%arg0: index) -> (i1, i1) {
+  %idx0 = index.constant 0
+  %idx1 = index.constant 1
+  %0 = index.maxs %arg0, %idx1
+  %1 = index.cmp sgt(%0, %idx0)
+  %2 = index.cmp eq(%0, %idx0)
+  // CHECK: return %true, %false
+  return %1, %2 : i1, i1
+}