Index: llvm/lib/Transforms/Scalar/MergeICmps.cpp =================================================================== --- llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -330,10 +330,10 @@ // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional visitCmpBlock(Value *const Val, - BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { +std::optional +visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate, + Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { if (Block->empty()) return std::nullopt; auto *const BranchI = dyn_cast(Block->getTerminator()); @@ -348,15 +348,23 @@ // that this does not mean that this is the last incoming value, blocks // can be reordered). Cond = Val; - ExpectedPredicate = ICmpInst::ICMP_EQ; + const auto *const ConstBase = cast(Baseline); + ExpectedPredicate = + ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + + // Early exits yield Baseline (true => NE chain); report the predicate. + Predicate = ExpectedPredicate; } else { + // All early-exit constants must equal the first block's (Baseline). + if (Baseline != Val) + return std::nullopt; // In this case, we expect a constant incoming value (the comparison is // chained).
const auto *const Const = cast(Val); LLVM_DEBUG(dbgs() << "const\n"); - if (!Const->isZero()) + if (!Const->isZero() && !Const->isOne()) return std::nullopt; - LLVM_DEBUG(dbgs() << "false\n"); + LLVM_DEBUG(dbgs() << *Const << "\n"); assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); Cond = BranchI->getCondition(); @@ -401,10 +409,13 @@ using ContiguousBlocks = std::vector; BCECmpChain(const std::vector &Blocks, PHINode &Phi, - AliasAnalysis &AA); + AliasAnalysis &AA, ICmpInst::Predicate &Predicate); bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, - DomTreeUpdater &DTU); + DomTreeUpdater &DTU, ICmpInst::Predicate Predicate); + + /// True if the whole chain collapses into a single merged block. + bool mergesIntoSingleBlock() const { return MergedBlocks_.size() == 1; } bool atLeastOneMerged() const { return any_of(MergedBlocks_, @@ -469,16 +480,18 @@ } BCECmpChain::BCECmpChain(const std::vector &Blocks, PHINode &Phi, - AliasAnalysis &AA) + AliasAnalysis &AA, ICmpInst::Predicate &Predicate) : Phi_(Phi) { assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons.
std::vector Comparisons; BaseIdentifier BaseId; + Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]); for (BasicBlock *const Block : Blocks) { assert(Block && "invalid block"); - std::optional Comparison = visitCmpBlock( - Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId); + std::optional Comparison = + visitCmpBlock(Baseline, Predicate, Phi.getIncomingValueForBlock(Block), + Block, Phi.getParent(), BaseId); if (!Comparison) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); return; @@ -602,7 +615,8 @@ BasicBlock *const InsertBefore, BasicBlock *const NextCmpBlock, PHINode &Phi, const TargetLibraryInfo &TLI, - AliasAnalysis &AA, DomTreeUpdater &DTU) { + AliasAnalysis &AA, DomTreeUpdater &DTU, + ICmpInst::Predicate Predicate) { assert(!Comparisons.empty() && "merging zero comparisons"); LLVMContext &Context = NextCmpBlock->getContext(); const BCECmpBlock &FirstCmp = Comparisons[0]; @@ -644,7 +658,7 @@ Value *const RhsLoad = Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); // There are no blocks to merge, just do the comparison.
- IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + IsEqual = Builder.CreateICmp(Predicate, LhsLoad, RhsLoad); } else { const unsigned TotalSizeBits = std::accumulate( Comparisons.begin(), Comparisons.end(), 0u, @@ -660,8 +674,8 @@ Lhs, Rhs, ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8), Builder, DL, &TLI); - IsEqual = Builder.CreateICmpEQ( - MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); + IsEqual = Builder.CreateICmp( + Predicate, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0)); } BasicBlock *const PhiBB = Phi.getParent(); @@ -682,7 +696,7 @@ } bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, - DomTreeUpdater &DTU) { + DomTreeUpdater &DTU, ICmpInst::Predicate Predicate) { assert(atLeastOneMerged() && "simplifying trivial BCECmpChain"); LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " << EntryBlock_->getName() << "\n"); @@ -693,7 +707,7 @@ BasicBlock *NextCmpBlock = Phi_.getParent(); for (const auto &Blocks : reverse(MergedBlocks_)) { InsertBefore = NextCmpBlock = mergeComparisons( - Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU); + Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate); } // Replace the original cmp chain with the new cmp chain by pointing all @@ -829,14 +843,23 @@ const auto Blocks = getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues()); if (Blocks.empty()) return false; - BCECmpChain CmpChain(Blocks, Phi, AA); + ICmpInst::Predicate Predicate = CmpInst::BAD_ICMP_PREDICATE; + BCECmpChain CmpChain(Blocks, Phi, AA, Predicate); if (!CmpChain.atLeastOneMerged()) { LLVM_DEBUG(dbgs() << "skip: nothing merged\n"); return false; } - return CmpChain.simplify(TLI, AA, DTU); + assert(Predicate != CmpInst::BAD_ICMP_PREDICATE && + "a merged chain must have set the predicate"); + // Only the group feeding the phi directly may use an NE compare; the + // early exits of earlier groups assume an EQ chain, so bail out instead. + if (Predicate == ICmpInst::ICMP_NE && !CmpChain.mergesIntoSingleBlock()) { + LLVM_DEBUG(dbgs() << "skip: ne chain does not merge into one block\n"); + return false; + } + return CmpChain.simplify(TLI, AA, DTU, Predicate); } static bool runImpl(Function &F, const TargetLibraryInfo &TLI, Index: llvm/test/Transforms/MergeICmps/X86/pr59740.ll
=================================================================== --- /dev/null +++ llvm/test/Transforms/MergeICmps/X86/pr59740.ll @@ -0,0 +1,140 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s + +%struct.S = type { i8, i8, i8, i8 } + +define noundef i1 @_Z2neR1SS0_(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @_Z2neR1SS0_( +; CHECK-NEXT: "bb0+bb1+bb2+bb3": +; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0 +; CHECK-NEXT: br label [[BB4:%.*]] +; CHECK: bb4: +; CHECK-NEXT: ret i1 [[TMP0]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb4 + +bb1: ; preds = %bb0 + %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v2 = load i8, ptr %s2, align 1 + %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v3 = load i8, ptr %s3, align 1 + %cmp1 = icmp eq i8 %v2, %v3 + br i1 %cmp1, label %bb2, label %bb4 + +bb2: ; preds = %bb1 + %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2 + %v4 = load i8, ptr %s4, align 1 + %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2 + %v5 = load i8, ptr %s5, align 1 + %cmp2 = icmp eq i8 %v4, %v5 + br i1 %cmp2, label %bb3, label %bb4 + +bb3: ; preds = %bb2 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb4 + +bb4: ; preds = %bb0, %bb1, %bb2, %bb3 + %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ] + ret i1 %cmp +} + +; Negative test: 'false' early exit with a final 'icmp ne' (mixed EQ/NE chain) +define noundef i1
@cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @cmp_ne_incorrect_const( +; CHECK-NEXT: bb0: +; CHECK-NEXT: [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1 +; CHECK-NEXT: [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]] +; CHECK-NEXT: br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]] +; CHECK: bb1: +; CHECK-NEXT: [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1 +; CHECK-NEXT: [[V6:%.*]] = load i8, ptr [[S6]], align 1 +; CHECK-NEXT: [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1 +; CHECK-NEXT: [[V7:%.*]] = load i8, ptr [[S7]], align 1 +; CHECK-NEXT: [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]] +; CHECK-NEXT: br label [[BB2]] +; CHECK: bb2: +; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ] +; CHECK-NEXT: ret i1 [[CMP]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb2 + +bb1: ; preds = %bb0 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb2 + +bb2: ; preds = %bb0, %bb1 + %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 ] + ret i1 %cmp +} + +; Negative test: an NE chain whose fields form two merge groups ({0,1} and +; {3}) must not be merged: only single-group NE chains are rewritten. +define noundef i1 @cmp_ne_two_groups(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @cmp_ne_two_groups( +; CHECK-NEXT: bb0: +; CHECK-NEXT: [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1 +; CHECK-NEXT: [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]] +; CHECK-NEXT: br i1 [[CMP0]], label [[BB1:%.*]], label [[BB3:%.*]] +; CHECK: bb1: +; CHECK-NEXT: [[S2:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1 +; CHECK-NEXT: [[V2:%.*]] = load i8, ptr [[S2]], align 1 +; CHECK-NEXT: [[S3:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1 +; CHECK-NEXT: [[V3:%.*]] = load i8, ptr [[S3]], align 1 +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i8 [[V2]], [[V3]] +; CHECK-NEXT: br i1 [[CMP1]], label [[BB2:%.*]], label [[BB3]] +; CHECK: bb2: +; CHECK-NEXT: [[S6:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S0]], i64 0, i32 3 +; CHECK-NEXT: [[V6:%.*]] = load i8, ptr [[S6]], align 1 +; CHECK-NEXT: [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 3 +; CHECK-NEXT: [[V7:%.*]] = load i8, ptr [[S7]], align 1 +; CHECK-NEXT: [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]] +; CHECK-NEXT: br label [[BB3]] +; CHECK: bb3: +; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ true, [[BB0:%.*]] ], [ true, [[BB1]] ], [ [[CMP3]], [[BB2]] ] +; CHECK-NEXT: ret i1 [[CMP]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb3 + +bb1: ; preds = %bb0 + %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v2 = load i8, ptr %s2, align 1 + %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v3 = load i8, ptr %s3, align 1 + %cmp1 = icmp eq i8 %v2, %v3 + br i1 %cmp1, label %bb2, label %bb3 + +bb2: ; preds = %bb1 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb3 + +bb3: ; preds = %bb0, %bb1, %bb2 + %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ %cmp3, %bb2 ] + ret i1 %cmp +}