diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -434,6 +434,12 @@
   /// Perform a signed saturating subtraction of two constant ranges.
   ConstantRange ssub_sat(const ConstantRange &Other) const;
 
+  /// Perform an unsigned saturating multiplication of two constant ranges.
+  ConstantRange umul_sat(const ConstantRange &Other) const;
+
+  /// Perform a signed saturating multiplication of two constant ranges.
+  ConstantRange smul_sat(const ConstantRange &Other) const;
+
   /// Perform an unsigned saturating left shift of this constant range by a
   /// value in \p Other.
   ConstantRange ushl_sat(const ConstantRange &Other) const;
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -1333,6 +1333,42 @@
   return getNonEmpty(std::move(NewL), std::move(NewU));
 }
 
+ConstantRange ConstantRange::umul_sat(const ConstantRange &Other) const {
+  if (isEmptySet() || Other.isEmptySet())
+    return getEmpty();
+
+  // Unsigned saturating multiplication is non-decreasing in both operands,
+  // so the result bounds come directly from the operands' unsigned bounds.
+  APInt NewL = getUnsignedMin().umul_sat(Other.getUnsignedMin());
+  APInt NewU = getUnsignedMax().umul_sat(Other.getUnsignedMax()) + 1;
+  return getNonEmpty(std::move(NewL), std::move(NewU));
+}
+
+ConstantRange ConstantRange::smul_sat(const ConstantRange &Other) const {
+  if (isEmptySet() || Other.isEmptySet())
+    return getEmpty();
+
+  // Because we could be dealing with negative numbers here, the lower bound is
+  // the smallest of the cartesian product of the lower and upper ranges;
+  // for example:
+  //   [-1,4) * [-2,3) = min(-1*-2, -1*2, 3*-2, 3*2) = -6.
+  // Similarly for the upper bound, swapping min for max.
+  //
+  // Saturation is a monotonic function of the exact product, so the min/max
+  // over the saturated corner products equals the saturated min/max of the
+  // exact corner products; no wider intermediate bit width is needed.
+
+  APInt Min = getSignedMin();
+  APInt Max = getSignedMax();
+  APInt OtherMin = Other.getSignedMin();
+  APInt OtherMax = Other.getSignedMax();
+
+  auto L = {Min.smul_sat(OtherMin), Min.smul_sat(OtherMax),
+            Max.smul_sat(OtherMin), Max.smul_sat(OtherMax)};
+  auto Compare = [](const APInt &A, const APInt &B) { return A.slt(B); };
+  return getNonEmpty(std::min(L, Compare), std::max(L, Compare) + 1);
+}
+
 ConstantRange ConstantRange::ushl_sat(const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
diff --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -2217,6 +2217,14 @@
   });
 }
 
+TEST_F(ConstantRangeTest, UMulSat) {
+  TestUnsignedBinOpExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.umul_sat(CR2);
+      },
+      [](const APInt &N1, const APInt &N2) { return N1.umul_sat(N2); });
+}
+
 TEST_F(ConstantRangeTest, UShlSat) {
   TestUnsignedBinOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {
@@ -2245,6 +2253,14 @@
   });
 }
 
+TEST_F(ConstantRangeTest, SMulSat) {
+  TestSignedBinOpExhaustive(
+      [](const ConstantRange &CR1, const ConstantRange &CR2) {
+        return CR1.smul_sat(CR2);
+      },
+      [](const APInt &N1, const APInt &N2) { return N1.smul_sat(N2); });
+}
+
 TEST_F(ConstantRangeTest, SShlSat) {
   TestSignedBinOpExhaustive(
       [](const ConstantRange &CR1, const ConstantRange &CR2) {