diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -69,6 +69,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
@@ -5997,6 +5998,68 @@
   return getUnknown(I);
 }
 
+/// Split two sums into the addends they share and what is left on each side.
+///
+/// \p LHSExpr and \p RHSExpr are each treated as a list of addends: the
+/// operands of a SCEVAddExpr, or the expression itself otherwise. An addend
+/// present on both sides goes into \p CommonExpr as many times as it occurs
+/// on the side with fewer occurrences; the surplus stays in \p NewLHSExpr or
+/// \p NewRHSExpr, so that
+///   LHSExpr == CommonExpr + NewLHSExpr
+///   RHSExpr == CommonExpr + NewRHSExpr
+/// Both inputs must have the same type.
+static void factorCommonOperandsOutOfAddExprs(
+    ScalarEvolution *SE, const SCEV *LHSExpr, const SCEV *RHSExpr,
+    const SCEV *&CommonExpr, const SCEV *&NewLHSExpr, const SCEV *&NewRHSExpr) {
+  Type *Ty = LHSExpr->getType();
+  assert(RHSExpr->getType() == Ty && "Unexpected types!");
+
+  // View an expression as a list of addends; anything that is not an add is a
+  // one-element list holding the expression itself.
+  auto getOps = [](const SCEV **Expr) -> SCEVNAryExpr::op_range {
+    if (auto *AddExpr = dyn_cast<SCEVAddExpr>(*Expr))
+      return AddExpr->operands();
+    return make_range(Expr, Expr + 1);
+  };
+
+  struct NumOccurrences {
+    unsigned LHS = 0, RHS = 0;
+  };
+  SmallDenseMap<const SCEV *, NumOccurrences> Ops;
+  // Ops is keyed by pointer, so iterating it directly would visit addends in
+  // an order that can differ between runs; keep first-seen order on the side.
+  SmallVector<const SCEV *> OpsInOrder;
+  auto getCount = [&](const SCEV *Op) -> NumOccurrences & {
+    auto Inserted = Ops.try_emplace(Op);
+    if (Inserted.second)
+      OpsInOrder.push_back(Op);
+    return Inserted.first->second;
+  };
+  for (const SCEV *Op : getOps(&LHSExpr))
+    ++getCount(Op).LHS;
+  for (const SCEV *Op : getOps(&RHSExpr))
+    ++getCount(Op).RHS;
+
+  SmallVector<const SCEV *> CommonOps, NewLHSOps, NewRHSOps;
+  for (const SCEV *Op : OpsInOrder) {
+    const NumOccurrences Times = Ops.lookup(Op);
+    const unsigned NumCommonOccurrences = std::min(Times.LHS, Times.RHS);
+
+    // Each addend contributes a term to all three lists (scaled by zero where
+    // it does not belong), so no list is empty when it is summed below.
+    CommonOps.emplace_back(
+        SE->getMulExpr(SE->getConstant(Ty, NumCommonOccurrences), Op));
+    NewLHSOps.emplace_back(SE->getMulExpr(
+        SE->getConstant(Ty, Times.LHS - NumCommonOccurrences), Op));
+    NewRHSOps.emplace_back(SE->getMulExpr(
+        SE->getConstant(Ty, Times.RHS - NumCommonOccurrences), Op));
+  }
+
+  CommonExpr = SE->getAddExpr(CommonOps);
+  NewLHSExpr = SE->getAddExpr(NewLHSOps);
+  NewRHSExpr = SE->getAddExpr(NewRHSOps);
+}
+
 static Optional<const SCEV *>
 createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
                               const SCEV *TrueExpr, const SCEV *FalseExpr) {
@@ -6005,6 +6068,10 @@
          TrueExpr->getType()->isIntegerTy() &&
          "Unexpected operands of a select.");
 
+  // i1 cond ? (iN base + iN X + iN C0) : (iN base + iN C1)
+  //   -->
+  // iN base + (i1 cond ? (iN X + iN C0) : iN C1)
+
   // i1 cond ? iN x : iN C --> C + (i1 cond ? (iN x - iN C) : iN 0)
   //                       --> C + (umin_seq (sext cond), x - C)
   //
@@ -6012,10 +6079,16 @@
   //                       --> C + (i1 ~cond ? (iN x - iN C) : iN 0)
   //                       --> C + (umin_seq (sext ~cond), x - C)
 
-  // FIXME: while we can't legally model the case where both of the hands
-  // are fully variable, we only require that the *difference* is constant.
-  if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
-    return None;
+  const SCEV *Base = SE->getZero(TrueExpr->getType());
+
+  // We require that at least one hand of the select be a constant.
+  if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr)) {
+    factorCommonOperandsOutOfAddExprs(SE, TrueExpr, FalseExpr, Base, TrueExpr,
+                                      FalseExpr);
+
+    if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
+      return None;
+  }
 
   const SCEV *X, *C;
   if (SE->getTypeSizeInBits(TrueExpr->getType()) >
@@ -6029,8 +6102,9 @@
     X = TrueExpr;
     C = FalseExpr;
   }
-  return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
-                                           /*Sequential=*/true));
+  return SE->getAddExpr(
+      Base, C,
+      SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C), /*Sequential=*/true));
 }
 
 static Optional<const SCEV *> createNodeForSelectViaUMinSeq(ScalarEvolution *SE,
diff --git a/llvm/test/Analysis/ScalarEvolution/pointer-select.ll b/llvm/test/Analysis/ScalarEvolution/pointer-select.ll
--- a/llvm/test/Analysis/ScalarEvolution/pointer-select.ll
+++ b/llvm/test/Analysis/ScalarEvolution/pointer-select.ll
@@ -176,7 +176,7 @@
 ; CHECK-NEXT: %false_ptr = getelementptr i8, ptr %obj, i64 24
 ; CHECK-NEXT: --> (24 + %base_offset + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: %r = select i1 %cond, ptr %true_ptr, ptr %false_ptr
-; CHECK-NEXT: --> %r U: full-set S: full-set
+; CHECK-NEXT: --> (42 + (-18 umin (-1 + (-1 * (sext i1 %cond to i64)))) + %base_offset + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: Determining loop execution counts for: @pointer_select_same_object_with_variable_base_offset__constant_offsets
 ;
   %obj = getelementptr i8, ptr %obj.base, i64 %base_offset
@@ -216,7 +216,7 @@
 ; CHECK-NEXT: %false_ptr = getelementptr i8, ptr %obj, i64 %false_off
 ; CHECK-NEXT: --> (%base_offset + %false_off + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: %r = select i1 %cond, ptr %true_ptr, ptr %false_ptr
-; CHECK-NEXT: --> %r U: full-set S: full-set
+; CHECK-NEXT: --> (42 + ((-1 + (-1 * (sext i1 %cond to i64))) umin_seq (-42 + %false_off)) + %base_offset + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: Determining loop execution counts for: @pointer_select_same_object_with_variable_base_offset__constant_offset_vs_variable_offset
 ;
   %obj = getelementptr i8, ptr %obj.base, i64 %base_offset
@@ -236,7 +236,7 @@
 ; CHECK-NEXT: %false_ptr = getelementptr i8, ptr %obj, i64 42
 ; CHECK-NEXT: --> (42 + %base_offset + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: %r = select i1 %cond, ptr %true_ptr, ptr %false_ptr
-; CHECK-NEXT: --> %r U: full-set S: full-set
+; CHECK-NEXT: --> (42 + ((sext i1 %cond to i64) umin_seq (-42 + %true_off)) + %base_offset + %obj.base) U: full-set S: full-set
 ; CHECK-NEXT: Determining loop execution counts for: @pointer_select_same_object_with_variable_base_offset__variable_offset_vs_constant_offset
 ;
   %obj = getelementptr i8, ptr %obj.base, i64 %base_offset