diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -2894,7 +2894,8 @@ for (auto &I : *BB) { if (auto *SI = dyn_cast(&I)) { auto *Cond = SI->getCondition(); - if (!isa(Cond) && L.isLoopInvariant(Cond)) + // restrict to simple boolean selects + if (!isa(Cond) && L.isLoopInvariant(Cond) && Cond->getType()->isIntegerTy(1)) UnswitchCandidates.push_back({&I, {Cond}}); } else if (CollectGuards && isGuard(&I)) { auto *Cond = diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-select.ll b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-select.ll --- a/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-select.ll +++ b/llvm/test/Transforms/SimpleLoopUnswitch/nontrivial-unswitch-select.ll @@ -5,6 +5,7 @@ declare i1 @foo() declare i1 @bar(i32) +declare i32 @llvm.vector.reduce.add.v2i32(<2 x i32>) define i32 @basic(i32 %N, i1 %cond, i32 %select_input) { ; CHECK-LABEL: define i32 @basic @@ -72,6 +73,46 @@ ret i32 %res } +define i32 @basic_veccond(i32 %N, <2 x i1> %cond, <2 x i32> %select_input) { +; CHECK-LABEL: define i32 @basic_veccond +; CHECK-SAME: (i32 [[N:%.*]], <2 x i1> [[COND:%.*]], <2 x i32> [[SELECT_INPUT:%.*]]) { +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[FOR_COND:%.*]] +; CHECK: for.cond: +; CHECK-NEXT: [[RES:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[ADD:%.*]], [[FOR_BODY:%.*]] ] +; CHECK-NEXT: [[I:%.*]] = phi i32 [ 0, [[ENTRY]] ], [ [[INC:%.*]], [[FOR_BODY]] ] +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[I]], [[N]] +; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_CLEANUP:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[COND1:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[SELECT_INPUT]], <2 x i32> +; CHECK-NEXT: [[VREDUCE:%.*]] = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> [[COND1]]) +; CHECK-NEXT: [[ADD]] = add nuw nsw i32 [[VREDUCE]], [[RES]] +; CHECK-NEXT: [[INC]] = add nuw nsw i32 [[I]], 1 +; CHECK-NEXT: br label [[FOR_COND]] +; CHECK: for.cond.cleanup: +; CHECK-NEXT: [[RES_LCSSA:%.*]] = phi i32 [ [[RES]], [[FOR_COND]] ] +; CHECK-NEXT: ret i32 [[RES_LCSSA]] +; +entry: + br label %for.cond + +for.cond: ; preds = %for.body, %entry + %res = phi i32 [ 0, %entry ], [ %add, %for.body ] + %i = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %cmp = icmp slt i32 %i, %N + br i1 %cmp, label %for.body, label %for.cond.cleanup + +for.body: ; preds = %for.cond + %cond1 = select <2 x i1> %cond, <2 x i32> %select_input, <2 x i32> + %vreduce = call i32 @llvm.vector.reduce.add.v2i32(<2 x i32> %cond1) + %add = add nuw nsw i32 %vreduce, %res + %inc = add nuw nsw i32 %i, 1 + br label %for.cond + +for.cond.cleanup: ; preds = %for.cond + ret i32 %res +} + define i32 @select_phi_input(i32 %N, i1 %cond) { ; CHECK-LABEL: define i32 @select_phi_input ; CHECK-SAME: (i32 [[N:%.*]], i1 [[COND:%.*]]) {