diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -229,6 +229,8 @@ InstructionSet BlockInsts; // The block requires splitting. bool RequireSplit = false; + // Original order of this block in the chain. + unsigned OrigOrder = 0; private: BCECmp Cmp; @@ -380,39 +382,83 @@ << Comparison.Rhs().BaseId << " + " << Comparison.Rhs().Offset << "\n"); LLVM_DEBUG(dbgs() << "\n"); + Comparison.OrigOrder = Comparisons.size(); Comparisons.push_back(std::move(Comparison)); } // A chain of comparisons. class BCECmpChain { - public: - BCECmpChain(const std::vector &Blocks, PHINode &Phi, - AliasAnalysis &AA); - - int size() const { return Comparisons_.size(); } +public: + using ContiguousBlocks = std::vector; -#ifdef MERGEICMPS_DOT_ON - void dump() const; -#endif // MERGEICMPS_DOT_ON + BCECmpChain(const std::vector &Blocks, PHINode &Phi, + AliasAnalysis &AA); bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, DomTreeUpdater &DTU); -private: - static bool IsContiguous(const BCECmpBlock &First, - const BCECmpBlock &Second) { - return First.Lhs().BaseId == Second.Lhs().BaseId && - First.Rhs().BaseId == Second.Rhs().BaseId && - First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && - First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; + bool atLeastOneMerged() const { + return any_of(MergedBlocks_, + [](const auto &Blocks) { return Blocks.size() > 1; }); } +private: PHINode &Phi_; - std::vector Comparisons_; + // The list of all blocks in the chain, grouped by contiguity. + std::vector MergedBlocks_; // The original entry block (before sorting); BasicBlock *EntryBlock_; }; +static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { + return First.Lhs().BaseId == Second.Lhs().BaseId && + First.Rhs().BaseId == Second.Rhs().BaseId && + First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && + First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; +} + +static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) { + unsigned MinOrigOrder = std::numeric_limits::max(); + for (const BCECmpBlock &Block : Blocks) + MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder); + return MinOrigOrder; +} + +/// Given a chain of comparison blocks, groups the blocks into contiguous +/// ranges that can be merged together into a single comparison. +static std::vector +mergeBlocks(std::vector &&Blocks) { + std::vector MergedBlocks; + + // Sort to detect continuous offsets. + llvm::sort(Blocks, + [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { + return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < + std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); + }); + + BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr; + for (BCECmpBlock &Block : Blocks) { + if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) { + MergedBlocks.emplace_back(); + LastMergedBlock = &MergedBlocks.back(); + } else { + LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into " + << LastMergedBlock->back().BB->getName() << "\n"); + } + LastMergedBlock->push_back(std::move(Block)); + } + + // While we allow reordering for merging, do not reorder unmerged comparisons. + // Doing so may introduce branch on poison. + llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks, + const BCECmpChain::ContiguousBlocks &RhsBlocks) { + return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks); + }); + + return MergedBlocks; +} + BCECmpChain::BCECmpChain(const std::vector &Blocks, PHINode &Phi, AliasAnalysis &AA) : Phi_(Phi) { @@ -492,46 +538,8 @@ return; } EntryBlock_ = Comparisons[0].BB; - Comparisons_ = std::move(Comparisons); -#ifdef MERGEICMPS_DOT_ON - errs() << "BEFORE REORDERING:\n\n"; - dump(); -#endif // MERGEICMPS_DOT_ON - // Reorder blocks by LHS. We can do that without changing the - // semantics because we are only accessing dereferencable memory. - llvm::sort(Comparisons_, - [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { - return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < - std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); - }); -#ifdef MERGEICMPS_DOT_ON - errs() << "AFTER REORDERING:\n\n"; - dump(); -#endif // MERGEICMPS_DOT_ON -} - -#ifdef MERGEICMPS_DOT_ON -void BCECmpChain::dump() const { - errs() << "digraph dag {\n"; - errs() << " graph [bgcolor=transparent];\n"; - errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n"; - errs() << " edge [color=black];\n"; - for (size_t I = 0; I < Comparisons_.size(); ++I) { - const auto &Comparison = Comparisons_[I]; - errs() << " \"" << I << "\" [label=\"%" - << Comparison.Lhs().Base()->getName() << " + " - << Comparison.Lhs().Offset << " == %" - << Comparison.Rhs().Base()->getName() << " + " - << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8) - << " bytes)\"];\n"; - const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB); - if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n"; - errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n"; - } - errs() << " \"Phi\" [label=\"Phi\"];\n"; - errs() << "}\n\n"; + MergedBlocks_ = mergeBlocks(std::move(Comparisons)); } -#endif // MERGEICMPS_DOT_ON namespace { @@ -655,47 +663,20 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, DomTreeUpdater &DTU) { - assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain"); - // First pass to check if there is at least one merge. If not, we don't do - // anything and we keep analysis passes intact. - const auto AtLeastOneMerged = [this]() { - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) - return true; - } - return false; - }; - if (!AtLeastOneMerged()) - return false; - + assert(atLeastOneMerged() && "simplifying trivial BCECmpChain"); LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " << EntryBlock_->getName() << "\n"); // Effectively merge blocks. We go in the reverse direction from the phi block // so that the next block is always available to branch to. - const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num, - BasicBlock *InsertBefore, - BasicBlock *Next) { - return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num), - InsertBefore, Next, Phi_, TLI, AA, DTU); - }; int NumMerged = 1; + BasicBlock *InsertBefore = EntryBlock_; BasicBlock *NextCmpBlock = Phi_.getParent(); - for (int I = static_cast(Comparisons_.size()) - 2; I >= 0; --I) { - if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) { - LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName() - << " into " << Comparisons_[I + 1].BB->getName() - << "\n"); - ++NumMerged; - } else { - NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock); - NumMerged = 1; - } + for (const auto &Blocks : reverse(MergedBlocks_)) { + NumMerged += Blocks.size() - 1; + InsertBefore = NextCmpBlock = mergeComparisons( + Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU); } - // Insert the entry block for the new chain before the old entry block. - // If the old entry block was the function entry, this ensures that the new - // entry can become the function entry. - NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock); // Replace the original cmp chain with the new cmp chain by pointing all // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp @@ -723,13 +704,16 @@ // Delete merged blocks. This also removes incoming values in phi. SmallVector DeadBlocks; - for (auto &Cmp : Comparisons_) { - LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n"); - DeadBlocks.push_back(Cmp.BB); + for (const auto &Blocks : MergedBlocks_) { + for (const BCECmpBlock &Block : Blocks) { + LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName() + << "\n"); + DeadBlocks.push_back(Block.BB); + } } DeleteDeadBlocks(DeadBlocks, &DTU); - Comparisons_.clear(); + MergedBlocks_.clear(); return true; } @@ -829,8 +813,8 @@ if (Blocks.empty()) return false; BCECmpChain CmpChain(Blocks, Phi, AA); - if (CmpChain.size() < 2) { - LLVM_DEBUG(dbgs() << "skip: only one compare block\n"); + if (!CmpChain.atLeastOneMerged()) { + LLVM_DEBUG(dbgs() << "skip: nothing merged\n"); return false; } diff --git a/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled-2.ll b/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled-2.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled-2.ll @@ -0,0 +1,70 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -mergeicmps < %s | FileCheck %s + +target triple = "x86_64-unknown-linux-gnu" + +%"struct.a::c" = type { i32, i32*, i8* } + +; The entry block cannot be merged as the comparison is not continuous. +; While it compares the highest address, it should not be moved after the +; other comparisons, as that would make the allocas non-dominating. + +define i1 @test() { +; CHECK-LABEL: @test( +; CHECK-NEXT: "land.lhs.true+entry": +; CHECK-NEXT: [[H:%.*]] = alloca %"struct.a::c", align 8 +; CHECK-NEXT: [[I:%.*]] = alloca %"struct.a::c", align 8 +; CHECK-NEXT: call void @init(%"struct.a::c"* [[H]]) +; CHECK-NEXT: call void @init(%"struct.a::c"* [[I]]) +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 1 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 1 +; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32** [[TMP0]] to i8* +; CHECK-NEXT: [[CSTR2:%.*]] = bitcast i32** [[TMP1]] to i8* +; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 16) +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0 +; CHECK-NEXT: br i1 [[TMP2]], label [[LAND_RHS1:%.*]], label [[LAND_END:%.*]] +; CHECK: land.rhs1: +; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[H]], i64 0, i32 0 +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds %"struct.a::c", %"struct.a::c"* [[I]], i64 0, i32 0 +; CHECK-NEXT: [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4 +; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4 +; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]] +; CHECK-NEXT: br label [[LAND_END]] +; CHECK: land.end: +; CHECK-NEXT: [[V9:%.*]] = phi i1 [ [[TMP7]], [[LAND_RHS1]] ], [ false, %"land.lhs.true+entry" ] +; CHECK-NEXT: ret i1 [[V9]] +; +entry: + %h = alloca %"struct.a::c", align 8 + %i = alloca %"struct.a::c", align 8 + call void @init(%"struct.a::c"* %h) + call void @init(%"struct.a::c"* %i) + %e = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 2 + %v3 = load i8*, i8** %e, align 8 + %e2 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 2 + %v4 = load i8*, i8** %e2, align 8 + %cmp = icmp eq i8* %v3, %v4 + br i1 %cmp, label %land.lhs.true, label %land.end + +land.lhs.true: ; preds = %entry + %d = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 1 + %v5 = load i32*, i32** %d, align 8 + %d3 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 1 + %v6 = load i32*, i32** %d3, align 8 + %cmp4 = icmp eq i32* %v5, %v6 + br i1 %cmp4, label %land.rhs, label %land.end + +land.rhs: ; preds = %land.lhs.true + %j = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %h, i64 0, i32 0 + %v7 = load i32, i32* %j, align 8 + %j5 = getelementptr inbounds %"struct.a::c", %"struct.a::c"* %i, i64 0, i32 0 + %v8 = load i32, i32* %j5, align 8 + %cmp6 = icmp eq i32 %v7, %v8 + br label %land.end + +land.end: ; preds = %land.rhs, %land.lhs.true, %entry + %v9 = phi i1 [ false, %land.lhs.true ], [ false, %entry ], [ %cmp6, %land.rhs ] + ret i1 %v9 +} + +declare void @init(%"struct.a::c"*) diff --git a/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled.ll b/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled.ll --- a/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled.ll +++ b/llvm/test/Transforms/MergeICmps/X86/entry-block-shuffled.ll @@ -9,20 +9,20 @@ define zeroext i1 @opeq1( ; CHECK-LABEL: @opeq1( -; CHECK-NEXT: "land.rhs.i+land.rhs.i.2": -; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 0 -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 0 -; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32* [[TMP0]] to i8* -; CHECK-NEXT: [[CSTR3:%.*]] = bitcast i32* [[TMP1]] to i8* -; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR3]], i64 8) -; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[MEMCMP]], 0 -; CHECK-NEXT: br i1 [[TMP2]], label [[ENTRY2:%.*]], label [[OPEQ1_EXIT:%.*]] -; CHECK: entry2: -; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3 -; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 2 -; CHECK-NEXT: [[TMP5:%.*]] = load i32, i32* [[TMP3]], align 4 -; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[TMP4]], align 4 -; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[TMP5]], [[TMP6]] +; CHECK-NEXT: entry3: +; CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds [[S:%.*]], %S* [[A:%.*]], i64 0, i32 3 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr inbounds [[S]], %S* [[B:%.*]], i64 0, i32 2 +; CHECK-NEXT: [[TMP2:%.*]] = load i32, i32* [[TMP0]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i32, i32* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[TMP2]], [[TMP3]] +; CHECK-NEXT: br i1 [[TMP4]], label %"land.rhs.i+land.rhs.i.2", label [[OPEQ1_EXIT:%.*]] +; CHECK: "land.rhs.i+land.rhs.i.2": +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 0 +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[S]], %S* [[B]], i64 0, i32 0 +; CHECK-NEXT: [[CSTR:%.*]] = bitcast i32* [[TMP5]] to i8* +; CHECK-NEXT: [[CSTR2:%.*]] = bitcast i32* [[TMP6]] to i8* +; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(i8* [[CSTR]], i8* [[CSTR2]], i64 8) +; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i32 [[MEMCMP]], 0 ; CHECK-NEXT: br i1 [[TMP7]], label [[LAND_RHS_I_31:%.*]], label [[OPEQ1_EXIT]] ; CHECK: land.rhs.i.31: ; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[S]], %S* [[A]], i64 0, i32 3 @@ -32,7 +32,7 @@ ; CHECK-NEXT: [[TMP12:%.*]] = icmp eq i32 [[TMP10]], [[TMP11]] ; CHECK-NEXT: br label [[OPEQ1_EXIT]] ; CHECK: opeq1.exit: -; CHECK-NEXT: [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, [[ENTRY2]] ], [ false, %"land.rhs.i+land.rhs.i.2" ] +; CHECK-NEXT: [[TMP13:%.*]] = phi i1 [ [[TMP12]], [[LAND_RHS_I_31]] ], [ false, %"land.rhs.i+land.rhs.i.2" ], [ false, [[ENTRY3:%.*]] ] ; CHECK-NEXT: ret i1 [[TMP13]] ; %S* nocapture readonly dereferenceable(16) %a,