diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -449,7 +449,7 @@ const Loop *getRelevantLoop(const SCEV *); Value *expandMinMaxExpr(const SCEVNAryExpr *S, Intrinsic::ID IntrinID, - Twine Name); + Twine Name, bool IsSequential = false); Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1671,11 +1671,16 @@ } Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S, - Intrinsic::ID IntrinID, Twine Name) { + Intrinsic::ID IntrinID, Twine Name, + bool IsSequential) { Value *LHS = expand(S->getOperand(S->getNumOperands() - 1)); Type *Ty = LHS->getType(); + if (IsSequential) + LHS = Builder.CreateFreeze(LHS); for (int i = S->getNumOperands() - 2; i >= 0; --i) { Value *RHS = expandCodeForImpl(S->getOperand(i), Ty, false); + if (IsSequential && i != 0) + RHS = Builder.CreateFreeze(RHS); Value *Sel; if (Ty->isIntegerTy()) Sel = Builder.CreateIntrinsic(IntrinID, {Ty}, {LHS, RHS}, @@ -1707,21 +1712,7 @@ } Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { - SmallVector Ops; - for (const SCEV *Op : S->operands()) - Ops.emplace_back(expand(Op)); - - Value *SaturationPoint = - MinMaxIntrinsic::getSaturationPoint(Intrinsic::umin, S->getType()); - - SmallVector OpIsZero; - for (Value *Op : ArrayRef(Ops).drop_back()) - OpIsZero.emplace_back(Builder.CreateICmpEQ(Op, SaturationPoint)); - - Value *AnyOpIsZero = Builder.CreateLogicalOr(OpIsZero); - - Value *NaiveUMin = expandMinMaxExpr(S, Intrinsic::umin, "umin"); - return Builder.CreateSelect(AnyOpIsZero, SaturationPoint, NaiveUMin); + return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true); } Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, diff --git a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll --- a/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll +++ b/llvm/test/Transforms/IndVarSimplify/exit-count-select.ll @@ -4,14 +4,13 @@ define i32 @logical_and_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_and_2ops( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) +; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[M:%.*]] ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 -; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] -; CHECK-NEXT: ret i32 [[TMP1]] +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP0]], i32 [[N:%.*]]) +; CHECK-NEXT: ret i32 [[UMIN]] ; entry: br label %loop @@ -29,14 +28,13 @@ define i32 @logical_or_2ops(i32 %n, i32 %m) { ; CHECK-LABEL: @logical_or_2ops( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[M:%.*]], i32 [[N:%.*]]) +; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[M:%.*]] ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[N]], 0 -; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 0, i32 [[UMIN]] -; CHECK-NEXT: ret i32 [[TMP1]] +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP0]], i32 [[N:%.*]]) +; CHECK-NEXT: ret i32 [[UMIN]] ; entry: br label %loop @@ -54,17 +52,15 @@ define i32 @logical_and_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_and_3ops( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[M:%.*]], 0 -; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) -; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) +; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[K:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = freeze i32 [[M:%.*]] +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP0]], i32 [[TMP1]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: br i1 false, label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 -; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i1 true, i1 [[TMP0]] -; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] -; CHECK-NEXT: ret i32 [[TMP3]] +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) +; CHECK-NEXT: ret i32 [[UMIN1]] ; entry: br label %loop @@ -84,17 +80,15 @@ define i32 @logical_or_3ops(i32 %n, i32 %m, i32 %k) { ; CHECK-LABEL: @logical_or_3ops( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[M:%.*]], 0 -; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[K:%.*]], i32 [[M]]) -; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) +; CHECK-NEXT: [[TMP0:%.*]] = freeze i32 [[K:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = freeze i32 [[M:%.*]] +; CHECK-NEXT: [[UMIN:%.*]] = call i32 @llvm.umin.i32(i32 [[TMP0]], i32 [[TMP1]]) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: br i1 true, label [[EXIT:%.*]], label [[LOOP]] ; CHECK: exit: -; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[N]], 0 -; CHECK-NEXT: [[TMP2:%.*]] = select i1 [[TMP1]], i1 true, i1 [[TMP0]] -; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i32 0, i32 [[UMIN1]] -; CHECK-NEXT: ret i32 [[TMP3]] +; CHECK-NEXT: [[UMIN1:%.*]] = call i32 @llvm.umin.i32(i32 [[UMIN]], i32 [[N:%.*]]) +; CHECK-NEXT: ret i32 [[UMIN1]] ; entry: br label %loop