diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -330,10 +330,10 @@ // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional visitCmpBlock(Value *const Val, - BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { +std::optional +visitCmpBlock(Value *const Baseline, ICmpInst::Predicate &Predicate, + Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { if (Block->empty()) return std::nullopt; auto *const BranchI = dyn_cast(Block->getTerminator()); @@ -348,15 +348,27 @@ // that this does not mean that this is the last incoming value, blocks // can be reordered). Cond = Val; - ExpectedPredicate = ICmpInst::ICMP_EQ; + const auto *const ConstBase = cast(Baseline); + assert(ConstBase->getType()->isIntegerTy(1) && + "Select condition is not an i1?"); + ExpectedPredicate = + ConstBase->isOne() ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + + // Remember the correct predicate. + Predicate = ExpectedPredicate; } else { + // All the incoming values must be consistent. + if (Baseline != Val) + return std::nullopt; // In this case, we expect a constant incoming value (the comparison is // chained). 
const auto *const Const = cast(Val); + assert(Const->getType()->isIntegerTy(1) && + "Incoming value is not an i1?"); LLVM_DEBUG(dbgs() << "const\n"); - if (!Const->isZero()) + if (!Const->isZero() && !Const->isOne()) return std::nullopt; - LLVM_DEBUG(dbgs() << "false\n"); + LLVM_DEBUG(dbgs() << *Const << "\n"); assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); Cond = BranchI->getCondition(); @@ -417,6 +429,8 @@ std::vector MergedBlocks_; // The original entry block (before sorting); BasicBlock *EntryBlock_; + // Remember the predicate type of the chain. + ICmpInst::Predicate Predicate_; }; static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { @@ -475,10 +489,13 @@ // Now look inside blocks to check for BCE comparisons. std::vector Comparisons; BaseIdentifier BaseId; + Value *const Baseline = Phi.getIncomingValueForBlock(Blocks[0]); + Predicate_ = CmpInst::BAD_ICMP_PREDICATE; for (BasicBlock *const Block : Blocks) { assert(Block && "invalid block"); - std::optional Comparison = visitCmpBlock( - Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId); + std::optional Comparison = + visitCmpBlock(Baseline, Predicate_, Phi.getIncomingValueForBlock(Block), + Block, Phi.getParent(), BaseId); if (!Comparison) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); return; @@ -602,7 +619,8 @@ BasicBlock *const InsertBefore, BasicBlock *const NextCmpBlock, PHINode &Phi, const TargetLibraryInfo &TLI, - AliasAnalysis &AA, DomTreeUpdater &DTU) { + AliasAnalysis &AA, DomTreeUpdater &DTU, + ICmpInst::Predicate Predicate) { assert(!Comparisons.empty() && "merging zero comparisons"); LLVMContext &Context = NextCmpBlock->getContext(); const BCECmpBlock &FirstCmp = Comparisons[0]; @@ -623,7 +641,7 @@ else Rhs = FirstCmp.Rhs().LoadI->getPointerOperand(); - Value *IsEqual = nullptr; + Value *ICmpValue = nullptr; LLVM_DEBUG(dbgs() << "Merging " << 
Comparisons.size() << " comparisons -> "
                     << BB->getName() << "\n");
 
@@ -644,7 +662,7 @@
     Value *const RhsLoad =
         Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
     // There are no blocks to merge, just do the comparison.
-    IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+    ICmpValue = Builder.CreateICmp(Predicate, LhsLoad, RhsLoad);
   } else {
     const unsigned TotalSizeBits = std::accumulate(
         Comparisons.begin(), Comparisons.end(), 0u,
@@ -660,8 +678,8 @@
         Lhs, Rhs,
         ConstantInt::get(Builder.getIntNTy(SizeTBits), TotalSizeBits / 8),
         Builder, DL, &TLI);
-    IsEqual = Builder.CreateICmpEQ(
-        MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
+    ICmpValue = Builder.CreateICmp(
+        Predicate, MemCmpCall, ConstantInt::get(Builder.getIntNTy(IntBits), 0));
   }
 
   BasicBlock *const PhiBB = Phi.getParent();
@@ -669,11 +687,16 @@
   if (NextCmpBlock == PhiBB) {
     // Continue to phi, passing it the comparison result.
     Builder.CreateBr(PhiBB);
-    Phi.addIncoming(IsEqual, BB);
+    Phi.addIncoming(ICmpValue, BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
   } else {
-    // Continue to next block if equal, exit to phi else.
-    Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
-    Phi.addIncoming(ConstantInt::getFalse(Context), BB);
+    // Stay in the chain while the operands compare equal and leave to the phi
+    // on the first mismatch, feeding the phi the constant the original chain
+    // used on its early exits. An EQ chain exits on a false compare with
+    // `false`; an NE chain exits on a true compare with `true`.
+    const bool ExitOnTrue = Predicate == ICmpInst::ICMP_NE;
+    Builder.CreateCondBr(ICmpValue, ExitOnTrue ? PhiBB : NextCmpBlock,
+                         ExitOnTrue ? NextCmpBlock : PhiBB);
+    Phi.addIncoming(ConstantInt::getBool(Context, ExitOnTrue), BB);
     DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
                       {DominatorTree::Insert, BB, PhiBB}});
@@ -691,9 +714,11 @@
   // so that the next block is always available to branch to.
BasicBlock *InsertBefore = EntryBlock_; BasicBlock *NextCmpBlock = Phi_.getParent(); + assert(Predicate_ != CmpInst::BAD_ICMP_PREDICATE && + "Got the chain of comparisons"); for (const auto &Blocks : reverse(MergedBlocks_)) { InsertBefore = NextCmpBlock = mergeComparisons( - Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU); + Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU, Predicate_); } // Replace the original cmp chain with the new cmp chain by pointing all diff --git a/llvm/test/Transforms/MergeICmps/X86/pr59740.ll b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/MergeICmps/X86/pr59740.ll @@ -0,0 +1,86 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=mergeicmps -verify-dom-info -S -mtriple=x86_64-unknown-unknown | FileCheck %s + +%struct.S = type { i8, i8, i8, i8 } + +define noundef i1 @_Z2neR1SS0_(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @_Z2neR1SS0_( +; CHECK-NEXT: "bb0+bb1+bb2+bb3": +; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[S0:%.*]], ptr [[S1:%.*]], i64 4) +; CHECK-NEXT: [[TMP0:%.*]] = icmp ne i32 [[MEMCMP]], 0 +; CHECK-NEXT: br label [[BB4:%.*]] +; CHECK: bb4: +; CHECK-NEXT: ret i1 [[TMP0]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb4 + +bb1: ; preds = %bb0 + %s2 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v2 = load i8, ptr %s2, align 1 + %s3 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v3 = load i8, ptr %s3, align 1 + %cmp1 = icmp eq i8 %v2, %v3 + br i1 %cmp1, label %bb2, label %bb4 + +bb2: ; preds = %bb1 + %s4 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 2 + %v4 = load i8, ptr %s4, align 1 + %s5 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 2 + %v5 = load i8, ptr %s5, align 1 + 
%cmp2 = icmp eq i8 %v4, %v5 + br i1 %cmp2, label %bb3, label %bb4 + +bb3: ; preds = %bb2 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 3 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 3 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb4 + +bb4: ; preds = %bb0, %bb1, %bb2, %bb3 + %cmp = phi i1 [ true, %bb0 ], [ true, %bb1 ], [ true, %bb2 ], [ %cmp3, %bb3 ] + ret i1 %cmp +} + +; Negative test: Incorrect const value in PHI node +define noundef i1 @cmp_ne_incorrect_const(ptr nocapture readonly align 1 dereferenceable(4) %s0, ptr nocapture readonly align 1 dereferenceable(4) %s1) { +; CHECK-LABEL: @cmp_ne_incorrect_const( +; CHECK-NEXT: bb0: +; CHECK-NEXT: [[V0:%.*]] = load i8, ptr [[S0:%.*]], align 1 +; CHECK-NEXT: [[V1:%.*]] = load i8, ptr [[S1:%.*]], align 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i8 [[V0]], [[V1]] +; CHECK-NEXT: br i1 [[CMP0]], label [[BB1:%.*]], label [[BB2:%.*]] +; CHECK: bb1: +; CHECK-NEXT: [[S6:%.*]] = getelementptr inbounds [[STRUCT_S:%.*]], ptr [[S0]], i64 0, i32 1 +; CHECK-NEXT: [[V6:%.*]] = load i8, ptr [[S6]], align 1 +; CHECK-NEXT: [[S7:%.*]] = getelementptr inbounds [[STRUCT_S]], ptr [[S1]], i64 0, i32 1 +; CHECK-NEXT: [[V7:%.*]] = load i8, ptr [[S7]], align 1 +; CHECK-NEXT: [[CMP3:%.*]] = icmp ne i8 [[V6]], [[V7]] +; CHECK-NEXT: br label [[BB2]] +; CHECK: bb2: +; CHECK-NEXT: [[CMP:%.*]] = phi i1 [ false, [[BB0:%.*]] ], [ [[CMP3]], [[BB1]] ] +; CHECK-NEXT: ret i1 [[CMP]] +; +bb0: + %v0 = load i8, ptr %s0, align 1 + %v1 = load i8, ptr %s1, align 1 + %cmp0 = icmp eq i8 %v0, %v1 + br i1 %cmp0, label %bb1, label %bb2 + +bb1: ; preds = %bb0 + %s6 = getelementptr inbounds %struct.S, ptr %s0, i64 0, i32 1 + %v6 = load i8, ptr %s6, align 1 + %s7 = getelementptr inbounds %struct.S, ptr %s1, i64 0, i32 1 + %v7 = load i8, ptr %s7, align 1 + %cmp3 = icmp ne i8 %v6, %v7 + br label %bb2 + +bb2: ; preds = %bb0, %bb1 + %cmp = phi i1 [ false, %bb0 ], [ %cmp3, %bb1 
] + ret i1 %cmp +}