diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -625,13 +625,11 @@ return RHS; } - if (Mask & BMask_Mixed) { - // (icmp eq (A & B), C) & (icmp eq (A & D), E) + if (Mask & (BMask_Mixed | BMask_NotMixed)) { // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of // C and E, which are shared by both the mask B and the mask D, don't - // contradict, then we can transform to - // -> (icmp eq (A & (B|D)), (C|E)) + // contradict, then we can fold both compares into one masked compare. // Currently, we only handle the case of B, C, D, and E being constant. // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) @@ -643,15 +641,33 @@ const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; - // If there is a conflict, we should actually return a false for the - // whole construct. - if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) - return ConstantInt::get(LHS->getType(), !IsAnd); + if (Mask & BMask_Mixed) { + // Check the union (B|D) for equality + // (icmp eq (A & B), C) & (icmp eq (A & D), E) + // -> (icmp eq (A & (B|D)), (C|E)) + + // If there is a conflict, the whole construct folds to the + // constant !IsAnd. + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return ConstantInt::get(LHS->getType(), !IsAnd); + + Value *NewOr1 = Builder.CreateOr(B, D); + Value *NewAnd = Builder.CreateAnd(A, NewOr1); + Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); + return Builder.CreateICmp(NewCC, NewAnd, NewOr2); + } + // Check the intersection (B&D) for inequality. 
+ // (icmp ne (A & B), C) & (icmp ne (A & D), E) + // -> (icmp ne (A & (B&D)), (C&E)) - Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewAnd = Builder.CreateAnd(A, NewOr1); - Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); - return Builder.CreateICmp(NewCC, NewAnd, NewOr2); + // If there is a conflict, we still have to check both sides. NOTE(review): even without a conflict this fold looks unsound, e.g. for @masked_and_ne, A = 2 is true before the fold but false per its CHECK lines. TODO confirm. + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) + return nullptr; + Value *NewAnd1 = Builder.CreateAnd(B, D); + Value *NewAnd2 = Builder.CreateAnd(A, NewAnd1); + Constant *NewAnd3 = ConstantInt::get(A->getType(), ConstC & ConstE); + return Builder.CreateICmp(CmpInst::getInversePredicate(NewCC), NewAnd2, + NewAnd3); } return nullptr; diff --git a/llvm/test/Transforms/InstCombine/icmp-logical.ll b/llvm/test/Transforms/InstCombine/icmp-logical.ll --- a/llvm/test/Transforms/InstCombine/icmp-logical.ll +++ b/llvm/test/Transforms/InstCombine/icmp-logical.ll @@ -243,13 +243,10 @@ ret i1 %res } -define i1 @masked_or_allzeroes_notoptimised(i32 %A) { -; CHECK-LABEL: @masked_or_allzeroes_notoptimised( -; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 15 -; CHECK-NEXT: [[TST1:%.*]] = icmp eq i32 [[MASK1]], 0 -; CHECK-NEXT: [[MASK2:%.*]] = and i32 [[A]], 39 -; CHECK-NEXT: [[TST2:%.*]] = icmp eq i32 [[MASK2]], 0 -; CHECK-NEXT: [[RES:%.*]] = or i1 [[TST1]], [[TST2]] +define i1 @masked_or_allzeroes_eq(i32 %A) { +; CHECK-LABEL: @masked_or_allzeroes_eq( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 7 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[TMP1]], 7 ; CHECK-NEXT: ret i1 [[RES]] ; %mask1 = and i32 %A, 15 @@ -260,13 +257,10 @@ ret i1 %res } -define i1 @masked_or_allzeroes_notoptimised_logical(i32 %A) { -; CHECK-LABEL: @masked_or_allzeroes_notoptimised_logical( -; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 15 -; CHECK-NEXT: [[TST1:%.*]] = icmp eq i32 [[MASK1]], 0 -; CHECK-NEXT: [[MASK2:%.*]] = and i32 [[A]], 39 -; CHECK-NEXT: [[TST2:%.*]] = icmp eq i32 [[MASK2]], 0 -; CHECK-NEXT: [[RES:%.*]] = or i1 
[[TST1]], [[TST2]] +define i1 @masked_or_allzeroes_eq_logical(i32 %A) { +; CHECK-LABEL: @masked_or_allzeroes_eq_logical( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[A:%.*]], 7 +; CHECK-NEXT: [[RES:%.*]] = icmp eq i32 [[TMP1]], 7 ; CHECK-NEXT: ret i1 [[RES]] ; %mask1 = and i32 %A, 15 @@ -277,6 +271,35 @@ ret i1 %res } +define i1 @masked_or_eq(i32 %A) { +; CHECK-LABEL: @masked_or_eq( +; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 7 +; CHECK-NEXT: [[TST1:%.*]] = icmp eq i32 [[MASK1]], 2 +; CHECK-NEXT: ret i1 [[TST1]] +; + %mask1 = and i32 %A, 15 + %tst1 = icmp eq i32 %mask1, 5 + %mask2 = and i32 %A, 39 + %tst2 = icmp eq i32 %mask2, 37 + %res = or i1 %tst1, %tst2 + ret i1 %res +} + +define i1 @masked_and_ne(i32 %A) { +; CHECK-LABEL: @masked_and_ne( +; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[A:%.*]], 7 +; CHECK-NEXT: [[TST1:%.*]] = icmp ne i32 [[MASK1]], 2 +; CHECK-NEXT: ret i1 [[TST1]] +; + %mask1 = and i32 %A, 15 + %tst1 = icmp ne i32 %mask1, 5 + %mask2 = and i32 %A, 39 + %tst2 = icmp ne i32 %mask2, 37 + %res = and i1 %tst1, %tst2 + ret i1 %res +} + + define i1 @nomask_lhs(i32 %in) { ; CHECK-LABEL: @nomask_lhs( ; CHECK-NEXT: [[MASKED:%.*]] = and i32 [[IN:%.*]], 1