diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -625,17 +625,17 @@ return RHS; } - if (Mask & BMask_Mixed) { + if (Mask & BMask_Mixed || Mask & BMask_NotMixed) { + // (icmp eq (A & B), C) &/| (icmp eq (A & D), E) + // (icmp eq (A & B), C) & (icmp eq (A & D), E) - // We already know that B & C == C && D & E == E. - // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of - // C and E, which are shared by both the mask B and the mask D, don't - // contradict, then we can transform to // -> (icmp eq (A & (B|D)), (C|E)) - // Currently, we only handle the case of B, C, D, and E being constant. - // We can't simply use C and E because we might actually handle - // (icmp ne (A & B), B) & (icmp eq (A & D), D) - // with B and D, having a single bit set. + + // (icmp eq (A & B), C) | (icmp eq (A & D), E) + // -> (icmp eq (A & (B & D)), (C & E)) + + if (Mask & BMask_NotMixed) + NewCC = CmpInst::getInversePredicate(NewCC); const APInt *OldConstC, *OldConstE; if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; @@ -643,17 +643,53 @@ const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; - // If there is a conflict, we should actually return a false for the - // whole construct. + if (Mask & BMask_Mixed) { + // (icmp eq (A & B), C) & (icmp eq (A & D), E) + // We already know that B & C == C && D & E == E. + // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of + // C and E, which are shared by both the mask B and the mask D, don't + // contradict, then we can transform to + // -> (icmp eq (A & (B|D)), (C|E)) + // Currently, we only handle the case of B, C, D, and E being constant. + // We can't simply use C and E because we might actually handle + // (icmp ne (A & B), B) & (icmp eq (A & D), D) + // with B and D, having a single bit set. + + // If there is a conflict, we should actually return a false for the + // whole construct. + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return ConstantInt::get(LHS->getType(), !IsAnd); + Value *NewOr1 = Builder.CreateOr(B, D); + Value *NewAnd = Builder.CreateAnd(A, NewOr1); + Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); + return Builder.CreateICmp(NewCC, NewAnd, NewOr2); + } + // Mask & BMask_NotMixed + + // Check the intersection (B & D) for inequality. + // Assume that (B & D) == B || (B & D) == D, i.e B/D is a subset of D/B + // and (B & D) & (C ^ E) == 0, bits of C and E, which are shared by both the + // B and the D, don't contradict. + // Note that we can assume (~B & C) == 0 && (~D & E) == 0, previous + // operation should delete these icmps if it hadn't been met. + + // (icmp ne (A & B), C) & (icmp ne (A & D), E) + // -> (icmp ne (A & (B & D)), (C & E)) + const APInt ConstT = *ConstB & *ConstD; + if (ConstT != *ConstB && ConstT != *ConstD) + return nullptr; + + if ((~*ConstB & ConstC) != 0 || (~*ConstD & ConstE) != 0) + return nullptr; + + // If there is a conflict, we still have to check both sides. if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return ConstantInt::get(LHS->getType(), !IsAnd); + return nullptr; - Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewAnd = Builder.CreateAnd(A, NewOr1); - Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); - return Builder.CreateICmp(NewCC, NewAnd, NewOr2); + Value *NewAnd1 = Builder.CreateAnd(A, ConstT); + Value *NewAnd2 = ConstantInt::get(A->getType(), ConstC & ConstE); + return Builder.CreateICmp(NewCC, NewAnd1, NewAnd2); } - return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/icmp-logical.ll b/llvm/test/Transforms/InstCombine/icmp-logical.ll --- a/llvm/test/Transforms/InstCombine/icmp-logical.ll +++ b/llvm/test/Transforms/InstCombine/icmp-logical.ll @@ -280,11 +280,8 @@ define i1 @masked_or_eq(i32 %A) { ; CHECK-LABEL: @masked_or_eq( -; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 15 -; CHECK-NEXT: [[TST1:%.*]] = icmp eq i32 [[MASK1]], 3 -; CHECK-NEXT: [[MASK2:%.*]] = and i32 [[A]], 255 -; CHECK-NEXT: [[TST2:%.*]] = icmp eq i32 [[MASK2]], 243 -; CHECK-NEXT: [[RES:%.*]] = or i1 [[TST1]], [[TST2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 15 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[TMP1]], 3 ; CHECK-NEXT: ret i1 [[RES]] ; %mask1 = and i32 %A, 15 ; 0x0f @@ -314,11 +311,8 @@ define i1 @masked_and_ne(i32 %A) { ; CHECK-LABEL: @masked_and_ne( -; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 15 -; CHECK-NEXT: [[TST1:%.*]] = icmp ne i32 [[MASK1]], 3 -; CHECK-NEXT: [[MASK2:%.*]] = and i32 [[A]], 255 -; CHECK-NEXT: [[TST2:%.*]] = icmp ne i32 [[MASK2]], 243 -; CHECK-NEXT: [[RES:%.*]] = and i1 [[TST1]], [[TST2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 15 +; CHECK-NEXT: [[RES:%.*]] = icmp ne i32 [[TMP1]], 3 ; CHECK-NEXT: ret i1 [[RES]] ; %mask1 = and i32 %A, 15 ; 0x0f