[InstCombine] transformZExtICmp: honor DoTransform in the masked-bit-test fold

What the patch changes:
  In InstCombinerImpl::transformZExtICmp, the branch that matches
    icmp (and (shl 1, ShAmt), X), 0     (one-use icmp and one-use and)
  now returns Cmp immediately when DoTransform is false. Previously this branch
  went straight to building new IR through Builder:
    - Builder.CreateNot(X) for ICMP_EQ;
    - Builder.CreateLShr(X, ShAmt).
  It did this even though the caller only asked whether the (zext icmp) can be
  transformed. That is the meaning of DoTransform == false, per the doc comment
  this patch adds above the function.

  NOTE(review): returning Cmp (non-null) in check-only mode presumably signals
  "transformable", matching the function's other early-out paths. Those paths
  are outside the visible hunks, so confirm against the full function.

Regression test:
  zext.ll gains @zext_or_masked_bit_test. There, a masked bit test
  (icmp ne (and (shl 1, %b), %a), 0) is or'ed with an unrelated icmp before
  the zext. The expected result is
    or (and (lshr %a, %b), 1), (zext (icmp eq %ld, %b)).
  NOTE(review): the pre-fix failure mode for this input (crash vs. wrong IR)
  is not shown by this patch. Link the originating bug report if one exists.

NOTE(review): in this copy the entire patch is collapsed onto two physical
lines. Before git apply or patch can consume it, it must be re-split into its
original lines, with each context line's leading space and indentation restored.
The text above the first "diff --git" header is ignored by both tools.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -953,6 +953,9 @@ return nullptr; } +/// Transform (zext icmp) to bitwise / integer operations in order to +/// eliminate it. If DoTransform is false, just test whether the given +/// (zext icmp) can be transformed. Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, bool DoTransform) { // If we are just checking for a icmp eq of a single bit and zext'ing it @@ -1039,6 +1042,9 @@ if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) && match(Cmp->getOperand(0), m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) { + if (!DoTransform) + return Cmp; + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) X = Builder.CreateNot(X); Value *Lshr = Builder.CreateLShr(X, ShAmt); diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll --- a/llvm/test/Transforms/InstCombine/zext.ll +++ b/llvm/test/Transforms/InstCombine/zext.ll @@ -409,3 +409,29 @@ %r = zext i1 %cmp to i32 ret i32 %r } + +; Assert that zext(or(masked_bit_test, icmp)) can be correctly transformed to +; or(shifted_masked_bit, zext(icmp)) + +define void @zext_or_masked_bit_test(i32 %a, i32 %b, i32* %p) { +; CHECK-LABEL: @zext_or_masked_bit_test +; CHECK-NEXT: [[LD:%.*]] = load i32, i32* %p, align 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[LD]], %b +; CHECK-NEXT: [[SHR:%.*]] = lshr i32 %a, %b +; CHECK-NEXT: [[AND:%.*]] = and i32 [[SHR]], 1 +; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[CMP]] to i32 +; CHECK-NEXT: [[OR:%.*]] = or i32 [[AND]], [[EXT]] +; CHECK-NEXT: store i32 [[OR]], i32* %p, align 4 +; CHECK-NEXT: ret void +; + %ld = load i32, i32* %p, align 4 + %shl = shl i32 1, %b + %and = and i32 %shl, %a + %tobool = icmp ne i32 %and, 0 + %cmp = icmp eq i32 %ld, %b + %or = or i1 %tobool, %cmp + 
%conv = zext i1 %or to i32 + store i32 %conv, i32* %p, align 4 + ret void +} +