diff --git a/llvm/lib/Target/ARM/MVETailPredication.cpp b/llvm/lib/Target/ARM/MVETailPredication.cpp --- a/llvm/lib/Target/ARM/MVETailPredication.cpp +++ b/llvm/lib/Target/ARM/MVETailPredication.cpp @@ -457,13 +457,10 @@ // upperbound(TC) <= UINT_MAX - VectorWidth // unsigned SizeInBits = TripCount->getType()->getScalarSizeInBits(); - auto Diff = APInt(SizeInBits, ~0) - APInt(SizeInBits, VectorWidth); - uint64_t MaxMinusVW = Diff.getZExtValue(); - // FIXME: since ranges can be negative we work with signed ranges here, but - // we shouldn't extract the zext'ed values for them. - uint64_t UpperboundTC = SE->getSignedRange(TC).getUpper().getZExtValue(); + auto MaxMinusVW = APInt(SizeInBits, ~0) - APInt(SizeInBits, VectorWidth); + APInt UpperboundTC = SE->getUnsignedRangeMax(TC); - if (UpperboundTC > MaxMinusVW && !ForceTailPredication) { + if (UpperboundTC.ugt(MaxMinusVW) && !ForceTailPredication) { LLVM_DEBUG(dbgs() << "ARM TP: Overflow possible in tripcount rounding:\n"; dbgs() << "upperbound(TC) <= UINT_MAX - VectorWidth\n"; dbgs() << UpperboundTC << " <= " << MaxMinusVW << " == false\n";); @@ -501,8 +498,8 @@ auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth))); - ConstantRange RangeCeil = SE->getSignedRange(Ceil) ; - ConstantRange RangeTC = SE->getSignedRange(TC) ; + ConstantRange RangeCeil = SE->getUnsignedRange(Ceil) ; + ConstantRange RangeTC = SE->getUnsignedRange(TC) ; if (!RangeTC.isSingleElement()) { auto ZeroRange = ConstantRange(APInt(TripCount->getType()->getScalarSizeInBits(), 0));