diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1346,6 +1346,19 @@
   /// that is greater than or equal to the current width.
   APInt trunc(unsigned width) const;
 
+  /// Truncate to new width with unsigned saturation.
+  ///
+  /// If this APInt, treated as unsigned integer, can be losslessly truncated to
+  /// the new bitwidth, then return truncated APInt. Else, return max value.
+  APInt truncUSat(unsigned width) const;
+
+  /// Truncate to new width with signed saturation.
+  ///
+  /// If this APInt, treated as signed integer, can be losslessly truncated to
+  /// the new bitwidth, then return truncated APInt. Else, return either
+  /// signed min value if this APInt is negative, or signed max value.
+  APInt truncSSat(unsigned width) const;
+
   /// Sign extend to a new width.
   ///
   /// This operation sign extends the APInt to a new width. If the high order
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -884,6 +884,31 @@
   return Result;
 }
 
+// Truncate to new width with unsigned saturation.
+APInt APInt::truncUSat(unsigned width) const {
+  assert(width < BitWidth && "Invalid APInt Truncate request");
+  assert(width && "Can't truncate to 0 bits");
+
+  // Can we just losslessly truncate it?
+  if (isIntN(width))
+    return trunc(width);
+  // If not, then just return the new limit.
+  return APInt::getMaxValue(width);
+}
+
+// Truncate to new width with signed saturation.
+APInt APInt::truncSSat(unsigned width) const {
+  assert(width < BitWidth && "Invalid APInt Truncate request");
+  assert(width && "Can't truncate to 0 bits");
+
+  // Can we just losslessly truncate it?
+  if (isSignedIntN(width))
+    return trunc(width);
+  // If not, then just return the new limits.
+  return isNegative() ? APInt::getSignedMinValue(width)
+                      : APInt::getSignedMaxValue(width);
+}
+
 // Sign extend to a new width.
 APInt APInt::sext(unsigned Width) const {
   assert(Width > BitWidth && "Invalid APInt SignExtend request");
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1177,9 +1177,26 @@
 
 TEST(APIntTest, SaturatingMath) {
   APInt AP_10 = APInt(8, 10);
+  APInt AP_42 = APInt(8, 42);
   APInt AP_100 = APInt(8, 100);
   APInt AP_200 = APInt(8, 200);
 
+  EXPECT_EQ(APInt(7, 100), AP_100.truncUSat(7));
+  EXPECT_EQ(APInt(6, 63), AP_100.truncUSat(6));
+  EXPECT_EQ(APInt(5, 31), AP_100.truncUSat(5));
+
+  EXPECT_EQ(APInt(7, 127), AP_200.truncUSat(7));
+  EXPECT_EQ(APInt(6, 63), AP_200.truncUSat(6));
+  EXPECT_EQ(APInt(5, 31), AP_200.truncUSat(5));
+
+  EXPECT_EQ(APInt(7, 42), AP_42.truncSSat(7));
+  EXPECT_EQ(APInt(6, 31), AP_42.truncSSat(6));
+  EXPECT_EQ(APInt(5, 15), AP_42.truncSSat(5));
+
+  EXPECT_EQ(APInt(7, -56, true), AP_200.truncSSat(7));
+  EXPECT_EQ(APInt(6, -32, true), AP_200.truncSSat(6));
+  EXPECT_EQ(APInt(5, -16, true), AP_200.truncSSat(5));
+
   EXPECT_EQ(APInt(8, 200), AP_100.uadd_sat(AP_100));
   EXPECT_EQ(APInt(8, 255), AP_100.uadd_sat(AP_200));
   EXPECT_EQ(APInt(8, 255), APInt(8, 255).uadd_sat(APInt(8, 255)));