diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3174,6 +3174,21 @@
     }
   }
 
+  // select (~b & a), a, b --> or a, b
+  //
+  // The condition is true only when a == 1 and b == 0, and then the select
+  // yields a == 1 == (a | b). Otherwise it yields b, and either a == 0, so
+  // (a | b) == b, or b == 1, so (a | b) == 1 == b. The condition reads both a
+  // and b, so whenever either is poison the select is already poison; hence
+  // replacing it with the non-poison-blocking 'or' is a refinement.
+  //
+  // The 'and' has the operands' type and a select condition is i1 or
+  // <N x i1>, so a match already implies a boolean select; the identity holds
+  // lane-wise, so vector selects are folded too. Use m_c_And because the
+  // 'not' is not guaranteed to be the first operand of the 'and'.
+  if (match(CondVal, m_c_And(m_Not(m_Specific(FalseVal)), m_Specific(TrueVal))))
+    return BinaryOperator::CreateOr(TrueVal, FalseVal);
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/fold-xor-and-select-i1.ll b/llvm/test/Transforms/InstCombine/fold-xor-and-select-i1.ll
--- a/llvm/test/Transforms/InstCombine/fold-xor-and-select-i1.ll
+++ b/llvm/test/Transforms/InstCombine/fold-xor-and-select-i1.ll
@@ -4,10 +4,36 @@
+; Same as @max_if, but the 'not' is the second operand of the 'and'.
+define i1 @max_if_commuted(ptr %p, i1 %b) {
+; CHECK-LABEL: define i1 @max_if_commuted
+; CHECK-SAME: (ptr [[P:%.*]], i1 [[B:%.*]]) {
+; CHECK-NEXT:    [[A:%.*]] = load i1, ptr [[P]], align 1
+; CHECK-NEXT:    [[SEL:%.*]] = or i1 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 [[SEL]]
+;
+  %a = load i1, ptr %p, align 1
+  %not.b = xor i1 %b, true
+  %cmp = and i1 %a, %not.b
+  %sel = select i1 %cmp, i1 %a, i1 %b
+  ret i1 %sel
+}
+
+; Vector version of @max_if.
+define <2 x i1> @max_if_vec(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: define <2 x i1> @max_if_vec
+; CHECK-SAME: (<2 x i1> [[A:%.*]], <2 x i1> [[B:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = or <2 x i1> [[A]], [[B]]
+; CHECK-NEXT:    ret <2 x i1> [[SEL]]
+;
+  %not.b = xor <2 x i1> %b, <i1 true, i1 true>
+  %cmp = and <2 x i1> %not.b, %a
+  %sel = select <2 x i1> %cmp, <2 x i1> %a, <2 x i1> %b
+  ret <2 x i1> %sel
+}
+
 define i1 @max_if(i1 %a, i1 %b) {
 ; CHECK-LABEL: define i1 @max_if
 ; CHECK-SAME: (i1 [[A:%.*]], i1 [[B:%.*]]) {
-; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[B]], true
-; CHECK-NEXT:    [[CMP:%.*]] = and i1 [[TMP1]], [[A]]
-; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[CMP]], i1 [[A]], i1 [[B]]
-; CHECK-NEXT:    ret i1 [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i1 [[A]], [[B]]
+; CHECK-NEXT:    ret i1 [[TMP1]]
 ;
   %1 = xor i1 %b, true
   %cmp = and i1 %1, %a