diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -1109,6 +1109,8 @@
   APInt uadd_sat(const APInt &RHS) const;
   APInt ssub_sat(const APInt &RHS) const;
   APInt usub_sat(const APInt &RHS) const;
+  APInt smul_sat(const APInt &RHS) const;
+  APInt umul_sat(const APInt &RHS) const;
 
   /// Array-indexing support.
   ///
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2048,6 +2048,29 @@
   return APInt(BitWidth, 0);
 }
 
+APInt APInt::smul_sat(const APInt &RHS) const {
+  bool Overflow;
+  APInt Res = smul_ov(RHS, Overflow);
+  if (!Overflow)
+    return Res;
+
+  // An overflowing product has no zero operand, so the true (unbounded)
+  // product is negative iff exactly one of the inputs is negative.
+  bool ResIsNegative = isNegative() ^ RHS.isNegative();
+
+  return ResIsNegative ? APInt::getSignedMinValue(BitWidth)
+                       : APInt::getSignedMaxValue(BitWidth);
+}
+
+APInt APInt::umul_sat(const APInt &RHS) const {
+  bool Overflow;
+  APInt Res = umul_ov(RHS, Overflow);
+  if (!Overflow)
+    return Res;
+
+  // An unsigned product can only overflow upward, so clamp to the maximum.
+  return APInt::getMaxValue(BitWidth);
+}
 
 void APInt::fromString(unsigned numbits, StringRef str, uint8_t radix) {
   // Check our assumptions here
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -1197,6 +1197,30 @@
   EXPECT_EQ(APInt(8, 127), AP_100.ssub_sat(-AP_100));
   EXPECT_EQ(APInt(8, -128), (-AP_100).ssub_sat(AP_100));
   EXPECT_EQ(APInt(8, -128), APInt(8, -128).ssub_sat(APInt(8, 127)));
+
+  EXPECT_EQ(APInt(8, 250), APInt(8, 50).umul_sat(APInt(8, 5)));
+  EXPECT_EQ(APInt(8, 255), APInt(8, 50).umul_sat(APInt(8, 6)));
+  EXPECT_EQ(APInt(8, 255), APInt(8, -128).umul_sat(APInt(8, 3)));
+  EXPECT_EQ(APInt(8, 255), APInt(8, 3).umul_sat(APInt(8, -128)));
+  EXPECT_EQ(APInt(8, 255), APInt(8, -128).umul_sat(APInt(8, -128)));
+  EXPECT_EQ(APInt(8, 255), APInt(8, 255).umul_sat(APInt(8, 1)));
+  EXPECT_EQ(APInt(8, 0), APInt(8, 0).umul_sat(APInt(8, 255)));
+
+  EXPECT_EQ(APInt(8, 127), APInt(8, 127).smul_sat(APInt(8, 127)));
+  EXPECT_EQ(APInt(8, -125), APInt(8, -25).smul_sat(APInt(8, 5)));
+  EXPECT_EQ(APInt(8, -125), APInt(8, 25).smul_sat(APInt(8, -5)));
+  EXPECT_EQ(APInt(8, 125), APInt(8, -25).smul_sat(APInt(8, -5)));
+  EXPECT_EQ(APInt(8, 125), APInt(8, 25).smul_sat(APInt(8, 5)));
+  EXPECT_EQ(APInt(8, -128), APInt(8, -25).smul_sat(APInt(8, 6)));
+  EXPECT_EQ(APInt(8, -128), APInt(8, 25).smul_sat(APInt(8, -6)));
+  EXPECT_EQ(APInt(8, 127), APInt(8, -25).smul_sat(APInt(8, -6)));
+  EXPECT_EQ(APInt(8, 127), APInt(8, 25).smul_sat(APInt(8, 6)));
+
+  // -128 * -1 == 128 is not representable in 8 bits.
+  EXPECT_EQ(APInt(8, 127), APInt(8, -128).smul_sat(APInt(8, -1)));
+  EXPECT_EQ(APInt(8, 127), APInt(8, -1).smul_sat(APInt(8, -128)));
+  EXPECT_EQ(APInt(8, -128), APInt(8, -128).smul_sat(APInt(8, 1)));
+  EXPECT_EQ(APInt(8, 0), APInt(8, -128).smul_sat(APInt(8, 0)));
 }
 
 TEST(APIntTest, FromArray) {