Index: llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -150,6 +150,11 @@
                              SmallVectorImpl<InductiveRangeCheck> &Checks,
                              SmallPtrSetImpl<Value *> &Visited);
 
+  static bool ReassociateSubLHS(Loop *L, const ICmpInst *ICI, Value *VariantLHS,
+                                Value *InvariantRHS, ICmpInst::Predicate Pred,
+                                ScalarEvolution &SE, const SCEV *&Index,
+                                const SCEV *&End);
+
 public:
   const SCEV *getBegin() const { return Begin; }
   const SCEV *getStep() const { return Step; }
@@ -272,6 +277,11 @@
     // Both LHS and RHS are loop variant
     return false;
 
+  if (ReassociateSubLHS(L, ICI, LHS, RHS, Pred, SE, Index, End))
+    return true;
+
+  // TODO: support ReassociateAddLHS
+
   switch (Pred) {
   default:
     return false;
@@ -390,6 +400,92 @@
                                                   Checks, Visited);
 }
 
+// Try to parse range check in the form of "IV - Offset vs Limit" or "Offset -
+// IV vs Limit"
+bool InductiveRangeCheck::ReassociateSubLHS(
+    Loop *L, const ICmpInst *ICI, Value *VariantLHS, Value *InvariantRHS,
+    ICmpInst::Predicate Pred, ScalarEvolution &SE, const SCEV *&Index,
+    const SCEV *&End) {
+  auto GetWideType = [](const SCEV *X) {
+    auto *Ty = cast<IntegerType>(X->getType());
+    return IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+  };
+
+  // Only signed predicates are handled below, so bail out early otherwise.
+  if (!ICmpInst::isSigned(Pred))
+    return false;
+
+  Value *LHS, *RHS;
+  if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS))))
+    return false;
+
+  const SCEV *IV = SE.getSCEV(LHS);
+  const SCEV *Offset = SE.getSCEV(RHS);
+  const SCEV *Limit = SE.getSCEV(InvariantRHS);
+
+  // TODO: Can we remove this restriction?
+  if (!SE.willNotOverflow(Instruction::BinaryOps::Sub, ICmpInst::isSigned(Pred),
+                          IV, Offset, ICI)) {
+    return false;
+  }
+
+  bool OffsetSubtracted = false;
+  if (SE.isLoopInvariant(IV, L))
+    // "Offset - IV vs Limit"
+    std::swap(IV, Offset);
+  else if (SE.isLoopInvariant(Offset, L))
+    // "IV - Offset vs Limit"
+    OffsetSubtracted = true;
+  else
+    return false;
+
+  // The loop-variant operand must be an add recurrence.
+  if (!isa<SCEVAddRecExpr>(IV))
+    return false;
+
+  // We are going to reassociate the expression, but the computations can
+  // overflow. To avoid it, use an extended type for all operands and then add
+  // a runtime check whether overflow happens or not.
+  auto WideTy = GetWideType(IV);
+  const SCEV *Scale = SE.getOne(WideTy);
+  Limit = SE.getSignExtendExpr(Limit, WideTy);
+  Offset = SE.getSignExtendExpr(Offset, WideTy);
+  if (OffsetSubtracted)
+    Offset = SE.getNegativeSCEV(Offset);
+  else
+    Scale = SE.getNegativeSCEV(Scale);
+
+  // Here we have "IV*Scale + Offset vs Limit" with Scale equal to 1 or -1.
+
+  if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
+    // "IV*Scale + Offset >= Limit" -> "IV*(-Scale) + (-Offset) <= (-Limit)"
+    Scale = SE.getNegativeSCEV(Scale);
+    Offset = SE.getNegativeSCEV(Offset);
+    Limit = SE.getNegativeSCEV(Limit);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+  }
+
+  // Here we have "IV*Scale + Offset < Limit" (or "<="). Solving for IV means
+  // dividing by Scale: with Scale == 1 this gives the upper bound
+  // "IV < Limit - Offset", but with Scale == -1 the inequality flips into the
+  // lower bound "IV > Offset - Limit", which cannot be expressed as
+  // "Index < End". So only "IV - Offset < Limit" and "Offset - IV > Limit"
+  // (and their non-strict forms) are supported.
+  if (!Scale->isOne())
+    return false;
+
+  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
+    Index = IV;
+    // "Expr <= Limit" -> "Expr < Limit + 1"
+    if (Pred == ICmpInst::ICMP_SLE)
+      Limit = SE.getAddExpr(Limit, SE.getOne(WideTy));
+    // "IV + Offset < Limit" -> "IV < Limit - Offset"
+    End = SE.getMinusSCEV(Limit, Offset);
+    return true;
+  }
+  return false;
+}
+
 // Add metadata to the loop L to disable loop optimizations. Callers need to
 // confirm that optimizing loop L is not beneficial.
 static void DisableAllLoopOptsOnLoop(Loop &L) {
@@ -1567,11 +1663,22 @@
   // if latch check is more narrow.
   auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
   auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
+  auto *EndType = dyn_cast<IntegerType>(getEnd()->getType());
   // Do not work with pointer types.
-  if (!IVType || !RCType)
+  if (!IVType || !RCType || !EndType)
     return std::nullopt;
   if (IVType->getBitWidth() > RCType->getBitWidth())
     return std::nullopt;
+
+  if (EndType->getBitWidth() > RCType->getBitWidth()) {
+    assert(EndType->getBitWidth() == RCType->getBitWidth() * 2);
+    // End is computed in the extended type but will be truncated to the
+    // (narrower) type of the range check, so we would need a check that the
+    // result does not overflow the narrow type.
+    // TODO: Support runtime overflow check for End
+    return std::nullopt;
+  }
+
   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
   // variable, that may or may not exist as a real llvm::Value in the loop) and
   // this inductive range check is a range check on the "C + D * I" ("C" is