diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6577,6 +6577,43 @@
     if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I))
       return NI;
 
+  // If both icmp operands are selects on the same condition, the compare can
+  // be distributed over the select arms:
+  //   icmp Pred (select %cond, A, B), (select %cond, C, D)
+  //     --> select %cond, (icmp Pred A, C), (icmp Pred B, D)
+  // This only pays off when one of the two arm comparisons simplifies: its
+  // result is used directly and only the other pair of arms needs a new icmp.
+  // For example:
+  //   %tmp1 = select i1 %cond, i32 %y, i32 %x
+  //   %tmp2 = select i1 %cond, i32 %z, i32 %x
+  //   %cmp2 = icmp slt i32 %tmp2, %tmp1
+  // The false arms compare %x with itself, which is always false, so the
+  // result depends only on the true arms (%z vs. %y) when %cond is true.
+  // Thus, transform this into:
+  //   %cmp = icmp slt i32 %z, %y
+  //   %sel = select i1 %cond, i1 %cmp, i1 false
+  // At least one select must be single-use: that select dies once the icmp
+  // is replaced, so the rewrite does not increase the instruction count.
+  {
+    Value *Cond, *A, *B, *C, *D;
+    if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
+        match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) &&
+        (Op0->hasOneUse() || Op1->hasOneUse())) {
+      // Comparing the true arms simplifies to an existing value; only the
+      // false arms need a new icmp.
+      if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, B, D);
+        return SelectInst::Create(Cond, Res, NewICMP);
+      }
+      // Comparing the false arms simplifies to an existing value; only the
+      // true arms need a new icmp.
+      if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, A, C);
+        return SelectInst::Create(Cond, NewICMP, Res);
+      }
+    }
+  }
+
   // Try to optimize equality comparisons against alloca-based pointers.
   if (Op0->getType()->isPointerTy() && I.isEquality()) {
     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
diff --git a/llvm/test/Transforms/InstCombine/icmp-with-selects.ll b/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
--- a/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
@@ -7,10 +7,7 @@
 ; CHECK-LABEL: define i1 @both_sides_fold_slt
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -23,10 +20,8 @@
 ; CHECK-LABEL: define i1 @both_sides_fold_eq
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    ret i1 [[NOT_COND]]
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -39,9 +34,8 @@
 ; CHECK-LABEL: define i1 @one_side_fold_slt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -55,9 +49,9 @@
 ; CHECK-LABEL: define i1 @one_side_fold_sgt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL1]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL2]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -71,9 +65,9 @@
 ; CHECK-LABEL: define i1 @one_side_fold_eq
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[TMP0]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -120,9 +114,9 @@
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
 ; CHECK-NEXT:    call void @use(i32 [[COND1]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -155,9 +149,8 @@
 ; CHECK-LABEL: define <4 x i1> @fold_vector_ops
 ; CHECK-SAME: (<4 x i32> [[VAL1:%.*]], <4 x i32> [[VAL2:%.*]], <4 x i32> [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], <4 x i32> [[VAL1]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], <4 x i32> [[VAL2]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq <4 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], <4 x i1> [[TMP0]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>
 ; CHECK-NEXT:    ret <4 x i1> [[CMP]]
 ;
 entry:
@@ -171,9 +164,8 @@
 ; CHECK-LABEL: define <8 x i1> @fold_vector_cond_ops
 ; CHECK-SAME: (<8 x i32> [[VAL1:%.*]], <8 x i32> [[VAL2:%.*]], <8 x i32> [[PARAM:%.*]], <8 x i1> [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL1]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL2]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <8 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt <8 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select <8 x i1> [[COND]], <8 x i1> [[TMP0]], <8 x i1> zeroinitializer
 ; CHECK-NEXT:    ret <8 x i1> [[CMP]]
 ;
 entry: