[InstCombine] Make getKnownSign/getKnownSignOrZero members of InstCombinerImpl; NFC

Both helpers only threaded through the DataLayout, AssumptionCache and
DominatorTree that InstCombinerImpl already owns. Turn them into const member
functions so callers no longer pass DL/AC/DT explicitly, and let
signBitMustBeTheSame take the combiner (by const reference) instead of the
three analyses.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1001,10 +1001,9 @@
   return nullptr;
 }
 
-static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
-                                   const DataLayout &DL, AssumptionCache *AC,
-                                   DominatorTree *DT) {
-  KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT);
+std::optional<bool> InstCombinerImpl::getKnownSign(Value *Op,
+                                                   Instruction *CxtI) const {
+  KnownBits Known = computeKnownBits(Op, /*Depth=*/0, CxtI);
   if (Known.isNonNegative())
     return false;
   if (Known.isNegative())
@@ -1018,11 +1017,9 @@
       ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL);
 }
 
-static std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI,
-                                              const DataLayout &DL,
-                                              AssumptionCache *AC,
-                                              DominatorTree *DT) {
-  if (std::optional<bool> Sign = getKnownSign(Op, CxtI, DL, AC, DT))
+std::optional<bool>
+InstCombinerImpl::getKnownSignOrZero(Value *Op, Instruction *CxtI) const {
+  if (std::optional<bool> Sign = getKnownSign(Op, CxtI))
     return Sign;
 
   Value *X, *Y;
@@ -1034,12 +1031,11 @@
 
 /// Return true if two values \p Op0 and \p Op1 are known to have the same sign.
 static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI,
-                                 const DataLayout &DL, AssumptionCache *AC,
-                                 DominatorTree *DT) {
-  std::optional<bool> Known1 = getKnownSign(Op1, CxtI, DL, AC, DT);
+                                 const InstCombinerImpl &IC) {
+  std::optional<bool> Known1 = IC.getKnownSign(Op1, CxtI);
   if (!Known1)
     return false;
-  std::optional<bool> Known0 = getKnownSign(Op0, CxtI, DL, AC, DT);
+  std::optional<bool> Known0 = IC.getKnownSign(Op0, CxtI);
   if (!Known0)
     return false;
   return *Known0 == *Known1;
@@ -1532,8 +1528,7 @@
     if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
 
-    if (std::optional<bool> Known =
-            getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) {
+    if (std::optional<bool> Known = getKnownSignOrZero(IIOperand, II)) {
       // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y)
       // abs(x) -> x if x > 0 (include abs(x-y) --> x - y where x > y)
       if (!*Known)
@@ -1646,7 +1641,7 @@
       bool UseAndN = IID == Intrinsic::smin || IID == Intrinsic::umin;
 
       if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
-        auto KnownSign = getKnownSign(X, II, DL, &AC, &DT);
+        auto KnownSign = getKnownSign(X, II);
         if (KnownSign == std::nullopt) {
           UseOr = false;
           UseAndN = false;
@@ -2419,7 +2414,7 @@
       FastMathFlags InnerFlags = cast<FPMathOperator>(Src)->getFastMathFlags();
 
       if ((FMF.allowReassoc() && InnerFlags.allowReassoc()) ||
-          signBitMustBeTheSame(Exp, InnerExp, II, DL, &AC, &DT)) {
+          signBitMustBeTheSame(Exp, InnerExp, II, *this)) {
         // TODO: Add nsw/nuw probably safe if integer type exceeds exponent
         // width.
         Value *NewExp = Builder.CreateAdd(InnerExp, Exp);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -482,6 +482,13 @@
       Instruction::BinaryOps BinaryOp, bool IsSigned,
       Value *LHS, Value *RHS, Instruction *CxtI) const;
 
+  /// Return true if \p Op is known negative, false if it is known
+  /// non-negative, and std::nullopt if its sign cannot be determined.
+  std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI) const;
+  /// Return true if \p Op is known non-positive (<= 0), false if it is known
+  /// non-negative (>= 0), and std::nullopt if neither can be proven.
+  std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI) const;
+
   /// Performs a few simplifications for operators which are associative
   /// or commutative.
   bool SimplifyAssociativeOrCommutative(BinaryOperator &I);