Index: ScalarEvolution.cpp =================================================================== --- ScalarEvolution.cpp +++ ScalarEvolution.cpp @@ -843,49 +843,49 @@ struct SCEVDivision : public SCEVVisitor { public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, + // Computes the Quotient and Remainder of the division of Dividend by + // Divisor. + static void divide(ScalarEvolution &SE, const SCEV *Dividend, + const SCEV *Divisor, const SCEV **Quotient, const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); + assert(Dividend && Divisor && "Uninitialized SCEV"); - SCEVDivision D(SE, Numerator, Denominator); + SCEVDivision D(SE, Dividend, Divisor); // Check for the trivial case here to avoid having to check for it in the // rest of the code. - if (Numerator == Denominator) { + if (Dividend == Divisor) { *Quotient = D.One; *Remainder = D.Zero; return; } - if (Numerator->isZero()) { + if (Dividend->isZero()) { *Quotient = D.Zero; *Remainder = D.Zero; return; } // A simple case when N/1. The quotient is N. - if (Denominator->isOne()) { - *Quotient = Numerator; + if (Divisor->isOne()) { + *Quotient = Dividend; *Remainder = D.Zero; return; } - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast(Denominator)) { + // Split the Divisor when it is a product. + if (const SCEVMulExpr *T = dyn_cast(Divisor)) { const SCEV *Q, *R; - *Quotient = Numerator; + *Quotient = Dividend; for (const SCEV *Op : T->operands()) { divide(SE, *Quotient, Op, &Q, &R); *Quotient = Q; - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. + // Bail out when the Dividend is not divisible by one of the terms of + // the Divisor. if (!R->isZero()) { *Quotient = D.Zero; - *Remainder = Numerator; + *Remainder = Dividend; return; } } @@ -893,71 +893,118 @@ return; } - D.visit(Numerator); + D.visit(Dividend); *Quotient = D.Quotient; *Remainder = D.Remainder; } // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. - void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast(Denominator)) { - APInt NumeratorVal = Numerator->getAPInt(); - APInt DenominatorVal = D->getAPInt(); - uint32_t NumeratorBW = NumeratorVal.getBitWidth(); - uint32_t DenominatorBW = DenominatorVal.getBitWidth(); - - if (NumeratorBW > DenominatorBW) - DenominatorVal = DenominatorVal.sext(NumeratorBW); - else if (NumeratorBW < DenominatorBW) - NumeratorVal = NumeratorVal.sext(DenominatorBW); - - APInt QuotientVal(NumeratorVal.getBitWidth(), 0); - APInt RemainderVal(NumeratorVal.getBitWidth(), 0); - APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + // Expr by Divisor for the following functions with empty implementation. + void visitUDivExpr(const SCEVUDivExpr *Dividend) {} + void visitSMaxExpr(const SCEVSMaxExpr *Dividend) {} + void visitUMaxExpr(const SCEVUMaxExpr *Dividend) {} + void visitUnknown(const SCEVUnknown *Dividend) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Dividend) {} + + void treatCastDivisor(const SCEVCastExpr *CEDividend, const SCEV **Q, + const SCEV **R) { + if (const SCEVTruncateExpr *TruncDivisor = + dyn_cast(Divisor)) { + divide(SE, CEDividend->getOperand(), TruncDivisor->getOperand(), Q, R); + } + else + if (const SCEVZeroExtendExpr *ZExtDivisor = + dyn_cast(Divisor)) { + divide(SE, CEDividend->getOperand(), ZExtDivisor->getOperand(), Q, R); + } + else + if (const SCEVSignExtendExpr *SExtDivisor = + dyn_cast(Divisor)) { + divide(SE, CEDividend->getOperand(), SExtDivisor->getOperand(), Q, R); + } + else + divide(SE, CEDividend->getOperand(), Divisor, Q, R); + } + + void visitTruncateExpr(const SCEVTruncateExpr *Dividend) { + const SCEV *Q, *R; + treatCastDivisor(Dividend, &Q, &R); + + Quotient = SE.getTruncateExpr(Q, Dividend->getType()); + Remainder = SE.getTruncateExpr(R, Dividend->getType()); + } + + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Dividend) { + const SCEV *Q, *R; + treatCastDivisor(Dividend, &Q, &R); + + Quotient = SE.getZeroExtendExpr(Q, Dividend->getType()); + Remainder = SE.getZeroExtendExpr(R, Dividend->getType()); + } + + void visitSignExtendExpr(const SCEVSignExtendExpr *Dividend) { + const SCEV *Q, *R; + treatCastDivisor(Dividend, &Q, &R); + + Quotient = SE.getSignExtendExpr(Q, Dividend->getType()); + Remainder = SE.getSignExtendExpr(R, Dividend->getType()); + } + + void visitConstant(const SCEVConstant *Dividend) { + if (const SCEVConstant *D = dyn_cast(Divisor)) { + APInt DividendVal = Dividend->getAPInt(); + APInt DivisorVal = D->getAPInt(); + uint32_t DividendBW = DividendVal.getBitWidth(); + uint32_t DivisorBW = DivisorVal.getBitWidth(); + + if (DividendBW > DivisorBW) + DivisorVal = DivisorVal.sext(DividendBW); + else if (DividendBW < DivisorBW) + DividendVal = DividendVal.sext(DivisorBW); + + APInt QuotientVal(DividendVal.getBitWidth(), 0); + APInt RemainderVal(DividendVal.getBitWidth(), 0); + APInt::sdivrem(DividendVal, DivisorVal, QuotientVal, RemainderVal); Quotient = SE.getConstant(QuotientVal); Remainder = SE.getConstant(RemainderVal); return; } } - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + void visitAddRecExpr(const SCEVAddRecExpr *Dividend) { const SCEV *StartQ, *StartR, *StepQ, *StepR; - if (!Numerator->isAffine()) - return cannotDivide(Numerator); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - // Bail out if the types do not match. - Type *Ty = Denominator->getType(); - if (Ty != StartQ->getType() || Ty != StartR->getType() || - Ty != StepQ->getType() || Ty != StepR->getType()) - return cannotDivide(Numerator); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - } - - void visitAddExpr(const SCEVAddExpr *Numerator) { + if (!Dividend->isAffine()) + return cannotDivide(Dividend); + divide(SE, Dividend->getStart(), Divisor, &StartQ, &StartR); + divide(SE, Dividend->getStepRecurrence(SE), Divisor, &StepQ, &StepR); + + assert(Dividend->getStart()->getType() == StartQ->getType() && + StartQ->getType() == StartR->getType() && + "Expected matching types"); + assert(Dividend->getStepRecurrence(SE)->getType() == StepQ->getType() && + StepQ->getType() == StepR->getType() && + "Expected matching types"); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Dividend->getLoop(), + Dividend->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Dividend->getLoop(), + Dividend->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Dividend) { SmallVector Qs, Rs; - Type *Ty = Denominator->getType(); - for (const SCEV *Op : Numerator->operands()) { + Type *Ty = Dividend->getType(); + + for (const SCEV *Op : Dividend->operands()) { const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); + divide(SE, Op, Divisor, &Q, &R); // Bail out if types do not match. if (Ty != Q->getType() || Ty != R->getType()) - return cannotDivide(Numerator); + return cannotDivide(Dividend); + + assert(Ty == Q->getType() && Ty == R->getType() && + "Expected matching types"); Qs.push_back(Q); Rs.push_back(R); @@ -973,24 +1020,24 @@ Remainder = SE.getAddExpr(Rs); } - void visitMulExpr(const SCEVMulExpr *Numerator) { + void visitMulExpr(const SCEVMulExpr *Dividend) { SmallVector Qs; - Type *Ty = Denominator->getType(); + Type *Ty = Divisor->getType(); - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { + bool FoundDivisorTerm = false; + for (const SCEV *Op : Dividend->operands()) { // Bail out if types do not match. if (Ty != Op->getType()) - return cannotDivide(Numerator); + return cannotDivide(Dividend); - if (FoundDenominatorTerm) { + if (FoundDivisorTerm) { Qs.push_back(Op); continue; } - // Check whether Denominator divides one of the product operands. + // Check whether Divisor divides one of the product operands. const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); + divide(SE, Op, Divisor, &Q, &R); if (!R->isZero()) { Qs.push_back(Op); continue; @@ -998,13 +1045,13 @@ // Bail out if types do not match. if (Ty != Q->getType()) - return cannotDivide(Numerator); + return cannotDivide(Dividend); - FoundDenominatorTerm = true; + FoundDivisorTerm = true; Qs.push_back(Q); } - if (FoundDenominatorTerm) { + if (FoundDivisorTerm) { Remainder = Zero; if (Qs.size() == 1) Quotient = Qs[0]; @@ -1013,58 +1060,58 @@ return; } - if (!isa(Denominator)) - return cannotDivide(Numerator); + if (!isa(Divisor)) + return cannotDivide(Dividend); - // The Remainder is obtained by replacing Denominator by 0 in Numerator. + // The Remainder is obtained by replacing Divisor by 0 in Dividend. ValueToValueMap RewriteMap; - RewriteMap[cast(Denominator)->getValue()] = - cast(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + RewriteMap[cast(Divisor)->getValue()] = + cast(SE.getZero(Divisor->getType()))->getValue(); + Remainder = SCEVParameterRewriter::rewrite(Dividend, SE, RewriteMap, true); if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast(Denominator)->getValue()] = - cast(One)->getValue(); + // The Quotient is obtained by replacing Divisor by 1 in Dividend. + RewriteMap[cast(Divisor)->getValue()] = + cast(SE.getOne(Divisor->getType()))->getValue(); Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + SCEVParameterRewriter::rewrite(Dividend, SE, RewriteMap, true); return; } - // Quotient is (Numerator - Remainder) divided by Denominator. + // Quotient is (Dividend - Remainder) divided by Divisor. const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); + const SCEV *Diff = SE.getMinusSCEV(Dividend, Remainder); // This SCEV does not seem to simplify: fail the division here. - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) - return cannotDivide(Numerator); - divide(SE, Diff, Denominator, &Q, &R); + if (sizeOfSCEV(Diff) > sizeOfSCEV(Dividend)) + return cannotDivide(Dividend); + divide(SE, Diff, Divisor, &Q, &R); if (R != Zero) - return cannotDivide(Numerator); + return cannotDivide(Dividend); Quotient = Q; } private: - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, - const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getZero(Denominator->getType()); - One = SE.getOne(Denominator->getType()); + SCEVDivision(ScalarEvolution &S, const SCEV *Dividend, + const SCEV *Divisor) + : SE(S), Divisor(Divisor) { + Zero = SE.getZero(Dividend->getType()); + One = SE.getOne(Dividend->getType()); - // We generally do not know how to divide Expr by Denominator. We + // We generally do not know how to divide Expr by Divisor. We // initialize the division to a "cannot divide" state to simplify the rest // of the code. - cannotDivide(Numerator); + cannotDivide(Dividend); } // Convenience function for giving up on the division. We set the quotient to - // be equal to zero and the remainder to be equal to the numerator. - void cannotDivide(const SCEV *Numerator) { + // be equal to zero and the remainder to be equal to the dividend. + void cannotDivide(const SCEV *Dividend) { Quotient = Zero; - Remainder = Numerator; + Remainder = Dividend; } ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; + const SCEV *Divisor, *Quotient, *Remainder, *Zero, *One; }; } // end anonymous namespace @@ -3038,7 +3085,7 @@ if (const SCEVConstant *RHSC = dyn_cast(RHS)) { if (RHSC->getValue()->isOne()) return LHS; // X udiv 1 --> x - // If the denominator is zero, the result of the udiv is undefined. Don't + // If the divisor is zero, the result of the udiv is undefined. Don't // try to analyze it, because the resolution chosen here may differ from // the resolution chosen in other parts of the compiler. if (!RHSC->getValue()->isZero()) { @@ -9557,30 +9604,30 @@ if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { // Rules for division. - // We are going to perform some comparisons with Denominator and its + // We are going to perform some comparisons with Divisor and its // derivative expressions. In general case, creating a SCEV for it may // lead to a complex analysis of the entire graph, and in particular it // can request trip count recalculation for the same loop. This would // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid // this, we only want to create SCEVs that are constants in this section. - // So we bail if Denominator is not a constant. + // So we bail if Divisor is not a constant. if (!isa(LR)) return false; - auto *Denominator = cast(getSCEV(LR)); + auto *Divisor = cast(getSCEV(LR)); - // We want to make sure that LHS = FoundLHS / Denominator. If it is so, - // then a SCEV for the numerator already exists and matches with FoundLHS. - auto *Numerator = getExistingSCEV(LL); - if (!Numerator || Numerator->getType() != FoundLHS->getType()) + // We want to make sure that LHS = FoundLHS / Divisor. If it is so, + // then a SCEV for the dividend already exists and matches with FoundLHS. + auto *Dividend = getExistingSCEV(LL); + if (!Dividend || Dividend->getType() != FoundLHS->getType()) return false; - // Make sure that the numerator matches with FoundLHS and the denominator + // Make sure that the dividend matches with FoundLHS and the divisor // is positive. - if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) + if (!HasSameValue(Dividend, FoundLHS) || !isKnownPositive(Divisor)) return false; - auto *DTy = Denominator->getType(); + auto *DTy = Divisor->getType(); auto *FRHSTy = FoundRHS->getType(); if (DTy->isPointerTy() != FRHSTy->isPointerTy()) // One of types is a pointer and another one is not. We cannot extend @@ -9590,29 +9637,29 @@ return false; // Given that: - // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. + // FoundLHS > FoundRHS, LHS = FoundLHS / Divisor, Divisor > 0. auto *WTy = getWiderType(DTy, FRHSTy); - auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); + auto *DivisorExt = getNoopOrSignExtend(Divisor, WTy); auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); // Try to prove the following rule: - // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). + // (FoundRHS > Divisor - 2) && (RHS <= 0) => (LHS > RHS). // For example, given that FoundLHS > 2. It means that FoundLHS is at - // least 3. If we divide it by Denominator < 4, we will have at least 1. - auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); + // least 3. If we divide it by Divisor < 4, we will have at least 1. + auto *DenomMinusTwo = getMinusSCEV(DivisorExt, getConstant(WTy, 2)); if (isKnownNonPositive(RHS) && IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) return true; // Try to prove the following rule: - // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). + // (FoundRHS > -1 - Divisor) && (RHS < 0) => (LHS > RHS). // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. - // If we divide it by Denominator > 2, then: + // If we divide it by Divisor > 2, then: // 1. If FoundLHS is negative, then the result is 0. // 2. If FoundLHS is non-negative, then the result is non-negative. // Anyways, the result is non-negative. auto *MinusOne = getNegativeSCEV(getOne(WTy)); - auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); + auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DivisorExt); if (isKnownNegative(RHS) && IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) return true; @@ -10361,6 +10408,18 @@ // Return the number of product terms in S. static inline int numberOfTerms(const SCEV *S) { + if (const SCEVTruncateExpr *TruncDivisor = dyn_cast(S)) { + return numberOfTerms(TruncDivisor->getOperand()); + } + + if (const SCEVZeroExtendExpr *ZExtDivisor = dyn_cast(S)) { + return numberOfTerms(ZExtDivisor->getOperand()); + } + + if (const SCEVSignExtendExpr *SExtDivisor = dyn_cast(S)) { + return numberOfTerms(SExtDivisor->getOperand()); + } + if (const SCEVMulExpr *Expr = dyn_cast(S)) return Expr->getNumOperands(); return 1;