diff --git a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp --- a/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arithmetic/IR/InferIntRangeInterfaceImpls.cpp @@ -503,10 +503,37 @@ static ConstantIntRanges truncIRange(const ConstantIntRanges &range, Type destType) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); - APInt umin = range.umin().trunc(destWidth); - APInt umax = range.umax().trunc(destWidth); - APInt smin = range.smin().trunc(destWidth); - APInt smax = range.smax().trunc(destWidth); + // Unsigned truncation preserves order only if umin and umax agree on all bits + // at or above destWidth; otherwise use the full unsigned range. E.g. keeping + // the low 16 bits of [0xaaaabbbb, 0xccccbbbb] wraps through 0, so [0xbbbb, + // 0xbbbb] would be wrong. [256, 258]_i16 truncates to [0, 2]_i8; [255, 257] wraps. + bool hasUnsignedRollover = + range.umin().lshr(destWidth) != range.umax().lshr(destWidth); + APInt umin = hasUnsignedRollover ? APInt::getZero(destWidth) + : range.umin().trunc(destWidth); + APInt umax = hasUnsignedRollover ? APInt::getMaxValue(destWidth) + : range.umax().trunc(destWidth); + + // Signed truncation keeps [trunc(smin), trunc(smax)] valid when either: + // - smin and smax agree on all bits at or above destWidth - 1 (the high part + // plus the destination sign bit), or + // - the range already fits in the signed destination type, i.e. those parts + // are each all 1s or all 0s. Otherwise the full signed range is used. + // E.g. [256, 258]_i16 truncates to [0, 2]_i8, [-2, 2]_i16 to [-2, 2]_i8 and + // [-384, -382]_i16 to [-128, -126]_i8, but [255, 257]_i16 gets the full
+ // range. So does [-130, 0]_i16: -130_i16 (0xff7e) truncates to 0x7e = 126, + // which exceeds trunc(0) = 0, so [trunc(smin), trunc(smax)] is not a range. + APInt sminHighPart = range.smin().ashr(destWidth - 1); + APInt smaxHighPart = range.smax().ashr(destWidth - 1); + bool hasSignedOverflow = + (sminHighPart != smaxHighPart) && + !(sminHighPart.isAllOnes() && + (smaxHighPart.isAllOnes() || smaxHighPart.isZero())) && + !(sminHighPart.isZero() && smaxHighPart.isZero()); + APInt smin = hasSignedOverflow ? APInt::getSignedMinValue(destWidth) + : range.smin().trunc(destWidth); + APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth) + : range.smax().trunc(destWidth); return {umin, umax, smin, smax}; } diff --git a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir --- a/mlir/test/Dialect/Arithmetic/int-range-interface.mlir +++ b/mlir/test/Dialect/Arithmetic/int-range-interface.mlir @@ -463,14 +463,15 @@ %c-14_i16 = arith.constant -14 : i16 %ci16_smin = arith.constant 0xffff8000 : i32 %0 = arith.minsi %arg0, %c-14_i32 : i32 - %1 = arith.trunci %0 : i32 to i16 - %2 = arith.cmpi sle, %1, %c-14_i16 : i16 - %3 = arith.extsi %1 : i16 to i32 - %4 = arith.cmpi sle, %3, %c-14_i32 : i32 - %5 = arith.cmpi sge, %3, %ci16_smin : i32 - %6 = arith.andi %2, %4 : i1 - %7 = arith.andi %6, %5 : i1 - func.return %7 : i1 + %1 = arith.maxsi %0, %ci16_smin : i32 + %2 = arith.trunci %1 : i32 to i16 + %3 = arith.cmpi sle, %2, %c-14_i16 : i16 + %4 = arith.extsi %2 : i16 to i32 + %5 = arith.cmpi sle, %4, %c-14_i32 : i32 + %6 = arith.cmpi sge, %4, %ci16_smin : i32 + %7 = arith.andi %3, %5 : i1 + %8 = arith.andi %7, %6 : i1 + func.return %8 : i1 } // CHECK-LABEL: func @index_cast @@ -645,3 +646,69 @@ func.return %8 : i1 } +// Test for a bug where the naive implementation of truncation led to the cast +// value being set to [0, 0].
+// CHECK-LABEL: func.func @truncation_spillover +// CHECK: %[[unreplaced:.*]] = arith.index_cast +// CHECK: memref.store %[[unreplaced]] +func.func @truncation_spillover(%arg0 : memref<?xi32>) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c49 = arith.constant 49 : index + %0 = scf.for %arg1 = %c0 to %c2 step %c1 iter_args(%arg2 = %c0) -> index { + %1 = arith.divsi %arg2, %c49 : index + %2 = arith.index_cast %1 : index to i32 + memref.store %2, %arg0[%c0] : memref<?xi32> + %3 = arith.addi %arg2, %arg1 : index + scf.yield %3 : index + } + func.return %0 : index +} + +// CHECK-LABEL: func.func @trunc_catches_overflow +// CHECK: %[[sge:.*]] = arith.cmpi sge +// CHECK: return %[[sge]] +func.func @trunc_catches_overflow(%arg0 : i16) -> i1 { + %c0_i16 = arith.constant 0 : i16 + %c130_i16 = arith.constant 130 : i16 + %c0_i8 = arith.constant 0 : i8 + %0 = arith.maxui %arg0, %c0_i16 : i16 + %1 = arith.minui %0, %c130_i16 : i16 + %2 = arith.trunci %1 : i16 to i8 + %3 = arith.cmpi sge, %2, %c0_i8 : i8 + %4 = arith.cmpi uge, %2, %c0_i8 : i8 + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +} + +// CHECK-LABEL: func.func @trunc_respects_same_high_half +// CHECK: %[[false:.*]] = arith.constant false +// CHECK: return %[[false]] +func.func @trunc_respects_same_high_half(%arg0 : i16) -> i1 { + %c256_i16 = arith.constant 256 : i16 + %c257_i16 = arith.constant 257 : i16 + %c2_i8 = arith.constant 2 : i8 + %0 = arith.maxui %arg0, %c256_i16 : i16 + %1 = arith.minui %0, %c257_i16 : i16 + %2 = arith.trunci %1 : i16 to i8 + %3 = arith.cmpi sge, %2, %c2_i8 : i8 + func.return %3 : i1 +} + +// CHECK-LABEL: func.func @trunc_handles_small_signed_ranges +// CHECK: %[[true:.*]] = arith.constant true +// CHECK: return %[[true]] +func.func @trunc_handles_small_signed_ranges(%arg0 : i16) -> i1 { + %c-2_i16 = arith.constant -2 : i16 + %c2_i16 = arith.constant 2 : i16 + %c-2_i8 = arith.constant -2 : i8 + %c2_i8 = arith.constant 2 : i8 + %0 =
arith.maxsi %arg0, %c-2_i16 : i16 + %1 = arith.minsi %0, %c2_i16 : i16 + %2 = arith.trunci %1 : i16 to i8 + %3 = arith.cmpi sge, %2, %c-2_i8 : i8 + %4 = arith.cmpi sle, %2, %c2_i8 : i8 + %5 = arith.andi %3, %4 : i1 + func.return %5 : i1 +}