Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -54,6 +54,43 @@
   return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B);
 }
 
+/// Fold a select whose compare pins %x to the identity constant of a binop:
+///   %A = icmp eq/ne i8 %x, IdC       ; IdC: identity constant of `op`
+///   %B = op i8 %x, %z
+///   %C = select i1 %A, i8 %B, i8 %y  ; for `ne`, %B is the false operand
+/// into
+///   %C = select i1 %A, i8 %z, i8 %y  ; %z replaces %B in its operand slot
+/// Returns &Sel if an operand was replaced, else nullptr; Builder is unused.
+/// TODO: support for non-commutative and FP opcodes
+static Instruction *
+foldSelectInstWithBinaryOp(SelectInst &Sel, InstCombiner::BuilderTy &Builder) {
+  Value *Cond = Sel.getCondition();
+  Value *TrueVal = Sel.getTrueValue();
+  Value *FalseVal = Sel.getFalseValue();
+  // X == C is only known on the arm selected when the values compare equal.
+  Value *X, *Z;
+  Constant *C;
+  CmpInst::Predicate Pred;
+  if (match(Cond, m_ICmp(Pred, m_Value(X), m_Constant(C)))) {
+    if (Pred == ICmpInst::ICMP_EQ) {
+      if (match(TrueVal, m_c_BinOp(m_Specific(X), m_Value(Z))) &&
+          ConstantExpr::getBinOpIdentity(
+              cast<BinaryOperator>(TrueVal)->getOpcode(), X->getType()) == C) {
+        Sel.setTrueValue(Z);
+        return &Sel;
+      }
+    } else if (Pred == ICmpInst::ICMP_NE) {
+      if (match(FalseVal, m_c_BinOp(m_Specific(X), m_Value(Z))) &&
+          ConstantExpr::getBinOpIdentity(
+              cast<BinaryOperator>(FalseVal)->getOpcode(), X->getType()) == C) {
+        Sel.setFalseValue(Z);
+        return &Sel;
+      }
+    }
+  }
+  return nullptr;
+}
+
 /// This folds:
 /// select (icmp eq (and X, C1)), TC, FC
 /// iff C1 is a power 2 and the difference between TC and FC is a power-of-2.
@@ -1961,5 +1998,8 @@
   if (Instruction *Select = foldSelectCmpXchg(SI))
     return Select;
 
+  if (Instruction *Select = foldSelectInstWithBinaryOp(SI, Builder))
+    return Select;
+
   return nullptr;
-}
+}
Index: test/Transforms/InstCombine/select-binop-icmp.ll
===================================================================
--- test/Transforms/InstCombine/select-binop-icmp.ll
+++ test/Transforms/InstCombine/select-binop-icmp.ll
@@ -4,8 +4,7 @@
 define i32 @select_xor_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_xor_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[B:%.*]] = xor i32 [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, 0
@@ -14,11 +13,34 @@
   ret i32 %C
 }
 
+define i32 @select_xor_icmp2(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_xor_icmp2(
+; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[C]]
+;
+  %A = icmp ne i32 %x, 0
+  %B = xor i32 %x, %z
+  %C = select i1 %A, i32 %y, i32 %B
+  ret i32 %C
+}
+
+define i32 @select_xor_icmp_meta(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_xor_icmp_meta(
+; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]], !prof !0
+; CHECK-NEXT: ret i32 [[C]]
+;
+  %A = icmp eq i32 %x, 0
+  %B = xor i32 %x, %z
+  %C = select i1 %A, i32 %B, i32 %y, !prof !0
+  ret i32 %C
+}
+
 define i32 @select_mul_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_mul_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 1
-; CHECK-NEXT: [[B:%.*]] = mul i32 [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, 1
@@ -30,8 +52,7 @@
 define i32 @select_add_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_add_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[B:%.*]] = add i32 [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, 0
@@ -43,8 +64,7 @@
 define i32 @select_or_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_or_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[B:%.*]] = or i32 [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, 0
@@ -56,8 +76,7 @@
 define i32 @select_and_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], -1
-; CHECK-NEXT: [[B:%.*]] = and i32 [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, -1
@@ -69,8 +88,7 @@
 define <2 x i8> @select_xor_icmp_vec(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_xor_icmp_vec(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer
-; CHECK-NEXT: [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[Z:%.*]], <2 x i8> [[Y:%.*]]
 ; CHECK-NEXT: ret <2 x i8> [[C]]
 ;
   %A = icmp eq <2 x i8> %x,
@@ -81,22 +99,20 @@
 
 define <2 x i8> @select_xor_icmp_vec2(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_xor_icmp_vec2(
-; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]],
-; CHECK-NEXT: [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]]
-; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer
+; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[Z:%.*]], <2 x i8> [[Y:%.*]]
 ; CHECK-NEXT: ret <2 x i8> [[C]]
 ;
-  %A = icmp eq <2 x i8> %x,
+  %A = icmp ne <2 x i8> %x,
   %B = xor <2 x i8> %x, %z
-  %C = select <2 x i1> %A, <2 x i8> %B, <2 x i8> %y
+  %C = select <2 x i1> %A, <2 x i8> %y, <2 x i8> %B
   ret <2 x i8> %C
 }
 
 define i32 @select_xor_inv_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_xor_inv_icmp(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[B:%.*]] = xor i32 [[Z:%.*]], [[X]]
-; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
 ; CHECK-NEXT: ret i32 [[C]]
 ;
   %A = icmp eq i32 %x, 0
@@ -105,6 +121,18 @@
   ret i32 %C
 }
 
+define i32 @select_xor_inv_icmp2(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_xor_inv_icmp2(
+; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[Z:%.*]], i32 [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[C]]
+;
+  %A = icmp ne i32 %x, 0
+  %B = xor i32 %x, %z
+  %C = select i1 %A, i32 %y, i32 %B
+  ret i32 %C
+}
+
 ; Negative tests
 define i32 @select_xor_icmp_bad_1(i32 %x, i32 %y, i32 %z, i32 %k) {
 ; CHECK-LABEL: @select_xor_icmp_bad_1(
@@ -171,6 +199,19 @@
   ret i32 %C
 }
 
+define i32 @select_xor_icmp_bad_6(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_xor_icmp_bad_6(
+; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 1
+; CHECK-NEXT: [[B:%.*]] = xor i32 [[X]], [[Z:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select i1 [[A]], i32 [[B]], i32 [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[C]]
+;
+  %A = icmp ne i32 %x, 1
+  %B = xor i32 %x, %z
+  %C = select i1 %A, i32 %y, i32 %B
+  ret i32 %C
+}
+
 define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_xor_icmp_vec_bad(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]],
@@ -184,6 +225,19 @@
   ret <2 x i8> %C
 }
 
+define <2 x i8> @select_xor_icmp_vec_bad_2(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
+; CHECK-LABEL: @select_xor_icmp_vec_bad_2(
+; CHECK-NEXT: [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]],
+; CHECK-NEXT: [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]]
+; CHECK-NEXT: [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT: ret <2 x i8> [[C]]
+;
+  %A = icmp eq <2 x i8> %x,
+  %B = xor <2 x i8> %x, %z
+  %C = select <2 x i1> %A, <2 x i8> %B, <2 x i8> %y
+  ret <2 x i8> %C
+}
+
 define i32 @select_mul_icmp_bad(i32 %x, i32 %y, i32 %z, i32 %k) {
 ; CHECK-LABEL: @select_mul_icmp_bad(
 ; CHECK-NEXT: [[A:%.*]] = icmp eq i32 [[X:%.*]], 3
@@ -328,3 +382,5 @@
   %C = select i1 %A, float %B, float %y
   ret float %C
 }
+
+!0 = !{!"branch_weights", i32 2, i32 10}