Index: llvm/lib/Transforms/Scalar/MergeICmps.cpp =================================================================== --- llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -330,7 +330,8 @@ // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional visitCmpBlock(Value *const Val, +std::optional visitCmpBlock(Value *const Baseline, + Value *const Val, BasicBlock *const Block, const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { @@ -348,15 +349,20 @@ // that this does not mean that this is the last incoming value, blocks // can be reordered). Cond = Val; - ExpectedPredicate = ICmpInst::ICMP_EQ; + const auto *const ConstBase = cast(Baseline); + ExpectedPredicate = + ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; } else { + // All the incoming values must be consistent. + if (Baseline != Val) + return std::nullopt; // In this case, we expect a constant incoming value (the comparison is // chained). const auto *const Const = cast(Val); LLVM_DEBUG(dbgs() << "const\n"); - if (!Const->isZero()) + if (!Const->isZero() && !Const->isOne()) return std::nullopt; - LLVM_DEBUG(dbgs() << "false\n"); + LLVM_DEBUG(dbgs() << *Const << "\n"); assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); Cond = BranchI->getCondition(); @@ -475,9 +481,10 @@ // Now look inside blocks to check for BCE comparisons. 
std::vector Comparisons; BaseIdentifier BaseId; + Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]); for (BasicBlock *const Block : Blocks) { assert(Block && "invalid block"); - std::optional Comparison = visitCmpBlock( + std::optional Comparison = visitCmpBlock(Baseline, Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId); if (!Comparison) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); @@ -637,6 +644,19 @@ ToSplit->split(BB, AA); } + BasicBlock *LastBlock = nullptr; + for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) { + if (isa(Phi.getIncomingValue(I))) + continue; + // The only non-constant incoming value is from the last block. + LastBlock = Phi.getIncomingBlock(I); + } + auto *const BranchI = cast(LastBlock->getTerminator()); + assert(BranchI->isUnconditional() && + "The last link in the chain of comparisons"); + auto *CmpI = cast(Phi.getIncomingValueForBlock(LastBlock)); + ICmpInst::Predicate Predicate = CmpI->getPredicate(); + if (Comparisons.size() == 1) { LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); Value *const LhsLoad = @@ -644,7 +664,10 @@ Value *const RhsLoad = Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); // There are no blocks to merge, just do the comparison. 
- IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + if (Predicate == ICmpInst::ICMP_NE) + IsEqual = Builder.CreateICmpNE(LhsLoad, RhsLoad); + else + IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); } else { const unsigned TotalSizeBits = std::accumulate( Comparisons.begin(), Comparisons.end(), 0u, @@ -660,8 +683,12 @@ Lhs, Rhs, ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8), Builder, DL, &TLI); - IsEqual = Builder.CreateICmpEQ( - MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); + if (Predicate == ICmpInst::ICMP_NE) + IsEqual = Builder.CreateICmpNE( + MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); + else + IsEqual = Builder.CreateICmpEQ( + MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); } BasicBlock *const PhiBB = Phi.getParent(); Index: llvm/test/Transforms/MergeICmps/X86/pr59740.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/MergeICmps/X86/pr59740.ll @@ -0,0 +1,48 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s + +%struct.S = type { i8, i8, i8, i8 } + +define noundef i1 @_Z2neR1SS0_(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @_Z2neR1SS0_( +; CHECK-NEXT: "bb0+bb1+bb2+bb3": +; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0 +; CHECK-NEXT: br label [[BB4:%.*]] +; CHECK: bb4: +; CHECK-NEXT: ret i1 [[TMP0]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb4 + +bb1: ; preds = %bb0 + %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v2 = load i8, ptr %s2, align 1 + %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v3 = 
load i8, ptr %s3, align 1 + %cmp1 = icmp eq i8 %v2, %v3 + br i1 %cmp1, label %bb2, label %bb4 + +bb2: ; preds = %bb1 + %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2 + %v4 = load i8, ptr %s4, align 1 + %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2 + %v5 = load i8, ptr %s5, align 1 + %cmp2 = icmp eq i8 %v4, %v5 + br i1 %cmp2, label %bb3, label %bb4 + +bb3: ; preds = %bb2 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb4 + +bb4: ; preds = %bb0, %bb1, %bb2, %bb3 + %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ] + ret i1 %cmp +}