Index: llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
+++ llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp
@@ -156,6 +156,10 @@
                                   const SCEVAddRecExpr *&Index,
                                   const SCEV *&End);
 
+  static bool reassociateSubLHS(Loop *L, Value *VariantLHS, Value *InvariantRHS,
+                                ICmpInst::Predicate Pred, ScalarEvolution &SE,
+                                const SCEVAddRecExpr *&Index, const SCEV *&End);
+
 public:
   const SCEV *getBegin() const { return Begin; }
   const SCEV *getStep() const { return Step; }
@@ -281,6 +285,10 @@
   if (parseIvAgaisntLimit(L, LHS, RHS, Pred, SE, Index, End))
     return true;
 
+  if (reassociateSubLHS(L, LHS, RHS, Pred, SE, Index, End))
+    return true;
+
+  // TODO: support ReassociateAddLHS
   return false;
 }
 
@@ -346,6 +354,86 @@
   llvm_unreachable("default clause returns!");
 }
 
+// Try to parse a range check of the form "IV - Offset vs Limit" or
+// "Offset - IV vs Limit", where VariantLHS is the subtraction and InvariantRHS
+// is the limit. Only signed predicates are handled.
+//
+// On success, sets Index to the add recurrence used as the IV, sets End to the
+// reassociated bound "(Limit - Offset) * Scale" computed in a type twice as
+// wide as the IV's, and returns true. Otherwise returns false and leaves Index
+// and End untouched.
+bool InductiveRangeCheck::reassociateSubLHS(
+    Loop *L, Value *VariantLHS, Value *InvariantRHS, ICmpInst::Predicate Pred,
+    ScalarEvolution &SE, const SCEVAddRecExpr *&Index, const SCEV *&End) {
+  auto GetWideType = [](const SCEV *X) {
+    auto *Ty = cast<IntegerType>(X->getType());
+    return IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+  };
+
+  Value *LHS, *RHS;
+  if (!match(VariantLHS, m_Sub(m_Value(LHS), m_Value(RHS))))
+    return false;
+
+  const SCEV *IV = SE.getSCEV(LHS);
+  const SCEV *Offset = SE.getSCEV(RHS);
+  const SCEV *Limit = SE.getSCEV(InvariantRHS);
+
+  // TODO: Can we remove this restriction?
+  if (!SE.willNotOverflow(Instruction::BinaryOps::Sub, ICmpInst::isSigned(Pred),
+                          IV, Offset, cast<Instruction>(VariantLHS))) {
+    return false;
+  }
+
+  bool OffsetSubtracted = false;
+  if (SE.isLoopInvariant(IV, L))
+    // "Offset - IV vs Limit"
+    std::swap(IV, Offset);
+  else if (SE.isLoopInvariant(Offset, L))
+    // "IV - Offset vs Limit"
+    OffsetSubtracted = true;
+  else
+    return false;
+
+  const auto *AddRec = dyn_cast<SCEVAddRecExpr>(IV);
+  if (!AddRec)
+    return false;
+
+  // We are going to reassociate expression but the computations can overflow.
+  // To avoid it, let's use extended type for all operands and then add runtime
+  // check whether overflow happens or not
+  auto WideTy = GetWideType(AddRec);
+  const SCEV *Scale = SE.getOne(WideTy);
+  Limit = SE.getSignExtendExpr(Limit, WideTy);
+  Offset = SE.getSignExtendExpr(Offset, WideTy);
+  if (OffsetSubtracted)
+    Offset = SE.getNegativeSCEV(Offset);
+  else
+    Scale = SE.getNegativeSCEV(Scale);
+
+  // Here we have "IV*Scale + Offset vs Limit"
+
+  if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
+    // "IV*Scale + Offset >= Limit" -> "IV*(-Scale) + (-Offset) <= (-Limit)"
+    Scale = SE.getNegativeSCEV(Scale);
+    Offset = SE.getNegativeSCEV(Offset);
+    Limit = SE.getNegativeSCEV(Limit);
+    Pred = CmpInst::getSwappedPredicate(Pred);
+  }
+
+  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
+    Index = AddRec;
+    // "Expr <= Limit" -> "Expr < Limit + 1"
+    if (Pred == ICmpInst::ICMP_SLE)
+      Limit = SE.getAddExpr(Limit, SE.getOne(WideTy));
+    // "IV*Scale + Offset < Limit" -> IV < (Limit - Offset)/Scale
+    // We multiply by Scale because it's 1 or -1
+    assert(Scale->isOne() || Scale->isAllOnesValue());
+    End = SE.getMulExpr(SE.getMinusSCEV(Limit, Offset), Scale);
+    return true;
+  }
+  return false;
+}
+
 void InductiveRangeCheck::extractRangeChecksFromCond(
     Loop *L, ScalarEvolution &SE, Use &ConditionUse,
     SmallVectorImpl<InductiveRangeCheck> &Checks,
@@ -1590,11 +1678,22 @@
   // if latch check is more narrow.
   auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
   auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
+  auto *EndType = dyn_cast<IntegerType>(getEnd()->getType());
   // Do not work with pointer types.
-  if (!IVType || !RCType)
+  if (!IVType || !RCType || !EndType)
     return std::nullopt;
   if (IVType->getBitWidth() > RCType->getBitWidth())
     return std::nullopt;
+
+  if (EndType->getBitWidth() > RCType->getBitWidth()) {
+    assert(EndType->getBitWidth() == RCType->getBitWidth() * 2);
+    // End is computed in the extended type but will be truncated to the
+    // narrower type of the range check. Therefore we need a check that the
+    // result will not overflow in terms of the narrow type.
+    // TODO: Support runtime overflow check for End
+    return std::nullopt;
+  }
+
   // IndVar is of the form "A + B * I" (where "I" is the canonical induction
   // variable, that may or may not exist as a real llvm::Value in the loop) and
   // this inductive range check is a range check on the "C + D * I" ("C" is