Review notes (preamble; not part of any hunk):

Purpose: teach InstCombine to fold (A & C) | ~(A | D) into A ? C : ~D.
Since ~(A | D) == ~A & ~D, the existing (A & C) | (B & D) select matcher
applies once B is allowed to equal A and the false value is inverted.
- getSelectCondition() gains ABIsTheSame: when set, it requires A == B
  instead of B == ~A. Every case past that check (sext, constants) returns
  nullptr for now (TODO in the patch).
- matchSelectFromAndOr() gains InvertFalseVal (default false). It forwards
  the flag and materializes ~D as the select's false operand.
- The hunk at old line 3087 (presumably the 'or' visitor; confirm) tries all
  four operand orderings of the matched 'and' / 'or'. It only does so when
  the 'and' or the outer 'not' has a single use.
NOTE(review): in this copy, the vector constant operands of the
`xor <2 x i1>` CHECK lines are missing (those lines end in ", "). Restore
them from the regenerated test before applying.
NOTE(review): newlines and indentation are collapsed in this copy, so the
hunks will not apply verbatim.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2600,7 +2600,9 @@ /// We have an expression of the form (A & C) | (B & D). If A is a scalar or /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. -Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { +/// If ABIsTheSame is set, A == B is required instead, to detect (A & C) | ~(A | D). +Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B, + bool ABIsTheSame) { // We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); @@ -2608,7 +2610,7 @@ return nullptr; // If A is the 'not' operand of B and has enough signbits, we have our answer. - if (match(B, m_Not(m_Specific(A)))) { + if (ABIsTheSame ? (A == B) : match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. if (Ty->isIntOrIntVectorTy(1)) return A; @@ -2628,6 +2630,10 @@ return nullptr; } + // TODO: add support for sext and constant case + if (ABIsTheSame) + return nullptr; + // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) @@ -2676,14 +2682,17 @@ /// We have an expression of the form (A & C) | (B & D). Try to simplify this /// to "A' ? C : D", where A' is a boolean or vector of booleans. +/// When InvertFalseVal is set to true, we try to match the pattern +/// where we have peeked through a 'not' op and A and B are the same: +/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ?
C : ~D Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, - Value *D) { + Value *D, bool InvertFalseVal) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); A = peekThroughBitcast(A, true); B = peekThroughBitcast(B, true); - if (Value *Cond = getSelectCondition(A, B)) { + if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will @@ -2699,6 +2708,8 @@ SelTy = VectorType::get(EltTy, VecTy->getElementCount()); } Value *BitcastC = Builder.CreateBitCast(C, SelTy); + if (InvertFalseVal) + D = Builder.CreateNot(D); Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); @@ -3087,6 +3098,20 @@ } } + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_Not(m_Or(m_Value(B), m_Value(D)))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // (Cond & C) | ~(Cond | D) -> Cond ?
C : ~D + if (Value *V = matchSelectFromAndOr(A, C, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B, true)) + return replaceInstUsesWith(I, V); + } + // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -371,8 +371,9 @@ Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI, bool IsAnd, bool IsLogical = false); - Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); - Value *getSelectCondition(Value *A, Value *B); + Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D, + bool InvertFalseVal = false); + Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame); Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); diff --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll --- a/llvm/test/Transforms/InstCombine/logical-select.ll +++ b/llvm/test/Transforms/InstCombine/logical-select.ll @@ -991,10 +991,8 @@ define i1 @not_d_bools_commute00(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_commute00( -; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor i1 [[Y_C]], true -; CHECK-NEXT: [[AND1:%.*]] = and i1 [[C]], [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[Y:%.*]], true +; CHECK-NEXT: [[R:%.*]] = select
i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]] ; CHECK-NEXT: ret i1 [[R]] ; %y_c = or i1 %c, %y @@ -1006,10 +1004,8 @@ define i1 @not_d_bools_commute01(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_commute01( -; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[Y:%.*]], [[C:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor i1 [[Y_C]], true -; CHECK-NEXT: [[AND1:%.*]] = and i1 [[C]], [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[Y:%.*]], true +; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]] ; CHECK-NEXT: ret i1 [[R]] ; %y_c = or i1 %y, %c @@ -1021,10 +1017,8 @@ define i1 @not_d_bools_commute10(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_commute10( -; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor i1 [[Y_C]], true -; CHECK-NEXT: [[AND1:%.*]] = and i1 [[X:%.*]], [[C]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[Y:%.*]], true +; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]] ; CHECK-NEXT: ret i1 [[R]] ; %y_c = or i1 %c, %y @@ -1036,10 +1030,8 @@ define i1 @not_d_bools_commute11(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_commute11( -; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[Y:%.*]], [[C:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor i1 [[Y_C]], true -; CHECK-NEXT: [[AND1:%.*]] = and i1 [[X:%.*]], [[C]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[Y:%.*]], true +; CHECK-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]] ; CHECK-NEXT: ret i1 [[R]] ; %y_c = or i1 %y, %c @@ -1051,10 +1043,8 @@ define <2 x i1> @not_d_bools_vector(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y) { ; CHECK-LABEL: @not_d_bools_vector( -; CHECK-NEXT: [[Y_C:%.*]] = or <2 x i1> [[Y:%.*]], [[C:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor <2 x i1> [[Y_C]], -; CHECK-NEXT: [[AND1:%.*]] = and <2 x i1> [[X:%.*]], [[C]] -; CHECK-NEXT: [[R:%.*]] = or <2 x i1> [[AND1]], [[AND2]] +; CHECK-NEXT:
[[TMP1:%.*]] = xor <2 x i1> [[Y:%.*]], +; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[X:%.*]], <2 x i1> [[TMP1]] ; CHECK-NEXT: ret <2 x i1> [[R]] ; %y_c = or <2 x i1> %y, %c @@ -1066,10 +1056,8 @@ define <2 x i1> @not_d_bools_vector_poison(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y) { ; CHECK-LABEL: @not_d_bools_vector_poison( -; CHECK-NEXT: [[Y_C:%.*]] = or <2 x i1> [[Y:%.*]], [[C:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor <2 x i1> [[Y_C]], -; CHECK-NEXT: [[AND1:%.*]] = and <2 x i1> [[X:%.*]], [[C]] -; CHECK-NEXT: [[R:%.*]] = or <2 x i1> [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor <2 x i1> [[Y:%.*]], +; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[X:%.*]], <2 x i1> [[TMP1]] ; CHECK-NEXT: ret <2 x i1> [[R]] ; %y_c = or <2 x i1> %y, %c @@ -1081,11 +1069,9 @@ define i32 @not_d_allSignBits(i32 %cond, i32 %tval, i32 %fval) { ; CHECK-LABEL: @not_d_allSignBits( -; CHECK-NEXT: [[BITMASK:%.*]] = ashr i32 [[COND:%.*]], 31 -; CHECK-NEXT: [[A1:%.*]] = and i32 [[BITMASK]], [[TVAL:%.*]] -; CHECK-NEXT: [[OR:%.*]] = or i32 [[BITMASK]], [[FVAL:%.*]] -; CHECK-NEXT: [[A2:%.*]] = xor i32 [[OR]], -1 -; CHECK-NEXT: [[SEL:%.*]] = or i32 [[A1]], [[A2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[FVAL:%.*]], -1 +; CHECK-NEXT: [[DOTNOT2:%.*]] = icmp slt i32 [[COND:%.*]], 0 +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[DOTNOT2]], i32 [[TVAL:%.*]], i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[SEL]] ; %bitmask = ashr i32 %cond, 31 @@ -1099,9 +1085,9 @@ define i1 @not_d_bools_use2(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_use2( ; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[AND2:%.*]] = xor i1 [[Y_C]], true ; CHECK-NEXT: [[AND1:%.*]] = and i1 [[C]], [[X:%.*]] -; CHECK-NEXT: [[R:%.*]] = or i1 [[AND1]], [[AND2]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[Y]], true +; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i1 [[X]], i1 [[TMP1]] ; CHECK-NEXT: call void @use1(i1 [[AND1]]) ; CHECK-NEXT: call void @use1(i1 [[Y_C]]) ; CHECK-NEXT: ret i1 [[R]] @@
-1115,6 +1101,8 @@ ret i1 %r } +; negative test: neither operand has one use + define i1 @not_d_bools_negative_use2(i1 %c, i1 %x, i1 %y) { ; CHECK-LABEL: @not_d_bools_negative_use2( ; CHECK-NEXT: [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]]