diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -625,7 +625,8 @@
       return RHS;
   }
 
-  if (Mask & BMask_Mixed) {
+  if (Mask & (BMask_Mixed | BMask_NotMixed)) {
+    // Mixed:
     // (icmp eq (A & B), C) & (icmp eq (A & D), E)
     // We already know that B & C == C && D & E == E.
     // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
@@ -636,24 +637,53 @@
     // We can't simply use C and E because we might actually handle
     // (icmp ne (A & B), B) & (icmp eq (A & D), D)
     // with B and D, having a single bit set.
+
+    // NotMixed:
+    // (icmp ne (A & B), C) & (icmp ne (A & D), E)
+    // -> (icmp ne (A & (B & D)), (C & E))
+    // and the dual (icmp eq ...) | (icmp eq ...) -> (icmp eq ...).
+    // This is only valid when one mask is a subset of the other, i.e.
+    // (B & D) == B || (B & D) == D, and when the bits of C and E shared by
+    // both B and D don't contradict, i.e. (B & D) & (C ^ E) == 0; both are
+    // checked below. We can assume (~B & C) == 0 && (~D & E) == 0: icmps
+    // violating that are expected to have been simplified away earlier.
+
     const APInt *OldConstC, *OldConstE;
     if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE)))
       return nullptr;
 
-    const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC;
-    const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE;
+    // Shared body of both folds; IsNot selects the NotMixed variant.
+    auto FoldBMixed = [&](ICmpInst::Predicate CC, bool IsNot) -> Value * {
+      CC = IsNot ? CmpInst::getInversePredicate(CC) : CC;
+      const APInt ConstC = PredL != CC ? *ConstB ^ *OldConstC : *OldConstC;
+      const APInt ConstE = PredR != CC ? *ConstD ^ *OldConstE : *OldConstE;
 
-    // If there is a conflict, we should actually return a false for the
-    // whole construct.
-    if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
-      return ConstantInt::get(LHS->getType(), !IsAnd);
+      // Contradicting shared bits: for Mixed the whole construct folds to a
+      // constant (!IsAnd); for NotMixed we simply don't fold.
+      if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue())
+        return IsNot ? nullptr : ConstantInt::get(LHS->getType(), !IsAnd);
 
-    Value *NewOr1 = Builder.CreateOr(B, D);
-    Value *NewAnd = Builder.CreateAnd(A, NewOr1);
-    Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE);
-    return Builder.CreateICmp(NewCC, NewAnd, NewOr2);
-  }
+      if (IsNot && !ConstB->isSubsetOf(*ConstD) && !ConstD->isSubsetOf(*ConstB))
+        return nullptr;
 
+      APInt BD, CE;
+      if (IsNot) {
+        BD = *ConstB & *ConstD;
+        CE = ConstC & ConstE;
+      } else {
+        BD = *ConstB | *ConstD;
+        CE = ConstC | ConstE;
+      }
+      Value *NewAnd = Builder.CreateAnd(A, BD);
+      Constant *CEVal = ConstantInt::get(A->getType(), CE);
+      // Keep the constant on the RHS (canonical form), as the old code did.
+      return Builder.CreateICmp(CC, NewAnd, CEVal);
+    };
+
+    if (Mask & BMask_Mixed)
+      return FoldBMixed(NewCC, false);
+    if (Mask & BMask_NotMixed) // Not expected together with Mixed.
+      return FoldBMixed(NewCC, true);
+  }
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-logical.ll b/llvm/test/Transforms/InstCombine/icmp-logical.ll
--- a/llvm/test/Transforms/InstCombine/icmp-logical.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-logical.ll
@@ -280,11 +280,8 @@
 
 define i1 @masked_or_eq(i32 %A) {
 ; CHECK-LABEL: @masked_or_eq(
-; CHECK-NEXT:    [[MASK1:%.*]] = and i32 [[A:%.*]], 15
-; CHECK-NEXT:    [[TST1:%.*]] = icmp eq i32 [[MASK1]], 3
-; CHECK-NEXT:    [[MASK2:%.*]] = and i32 [[A]], 255
-; CHECK-NEXT:    [[TST2:%.*]] = icmp eq i32 [[MASK2]], 243
-; CHECK-NEXT:    [[RES:%.*]] = or i1 [[TST1]], [[TST2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 15
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i32 [[TMP1]], 3
 ; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %mask1 = and i32 %A, 15 ; 0x0f
@@ -314,11 +311,8 @@
 
 define i1 @masked_and_ne(i32 %A) {
 ; CHECK-LABEL: @masked_and_ne(
-; CHECK-NEXT:    [[MASK1:%.*]] = and i32 [[A:%.*]], 15
-; CHECK-NEXT:    [[TST1:%.*]] = icmp ne i32 [[MASK1]], 3
-; CHECK-NEXT:    [[MASK2:%.*]] = and i32 [[A]], 255
-; CHECK-NEXT:    [[TST2:%.*]] = icmp ne i32 [[MASK2]], 243
-; CHECK-NEXT:    [[RES:%.*]] = and i1 [[TST1]], [[TST2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[A:%.*]], 15
+; CHECK-NEXT:    [[RES:%.*]] = icmp ne i32 [[TMP1]], 3
 ; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %mask1 = and i32 %A, 15 ; 0x0f