diff --git a/llvm/include/llvm/Analysis/CodeMetrics.h b/llvm/include/llvm/Analysis/CodeMetrics.h --- a/llvm/include/llvm/Analysis/CodeMetrics.h +++ b/llvm/include/llvm/Analysis/CodeMetrics.h @@ -87,6 +87,11 @@ /// assume or similar intrinsics in the function). static void collectEphemeralValues(const Function *L, AssumptionCache *AC, SmallPtrSetImpl<const Value *> &EphValues); + + /// Collect a basic block's ephemeral values (those used only by an + /// assume or similar intrinsics in the basic block). + static void collectEphemeralValues(const BasicBlock *BB, AssumptionCache *AC, + SmallPtrSetImpl<const Value *> &EphValues); }; } diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp --- a/llvm/lib/Analysis/CodeMetrics.cpp +++ b/llvm/lib/Analysis/CodeMetrics.cpp @@ -111,6 +111,26 @@ completeEphemeralValues(Visited, Worklist, EphValues); } +void CodeMetrics::collectEphemeralValues( + const BasicBlock *BB, AssumptionCache *AC, + SmallPtrSetImpl<const Value *> &EphValues) { + SmallPtrSet<const Value *, 32> Visited; + SmallVector<const Value *, 16> Worklist; + + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + Instruction *I = cast<Instruction>(AssumeVH); + if (I->getParent() != BB) + continue; + + if (EphValues.insert(I).second) + appendSpeculatableOperands(I, Visited, Worklist); + } + + completeEphemeralValues(Visited, Worklist, EphValues); +} + /// Fill in the current structure with information gleaned from the specified /// block. 
void CodeMetrics::analyzeBasicBlock( diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/GuardUtils.h" @@ -2427,9 +2428,14 @@ } /// Return true if we can thread a branch across this block. -static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB) { +static bool BlockIsSimpleEnoughToThreadThrough(BasicBlock *BB, + AssumptionCache *AC) { int Size = 0; + SmallPtrSet<const Value *, 32> EphValues; + if (AC) + CodeMetrics::collectEphemeralValues(BB, AC, EphValues); + for (Instruction &I : BB->instructionsWithoutDebug()) { if (Size > MaxSmallBlockSize) return false; // Don't clone large BB's. @@ -2440,8 +2446,8 @@ return false; // We will delete Phis while threading, so Phis should not be accounted in - // block's size - if (!isa<PHINode>(I)) + // block's size. Ditto for ephemeral values which will also be deleted. + if (!isa<PHINode>(I) && !EphValues.count(&I)) ++Size; // We can only support instructions that do not define values that are @@ -2477,7 +2483,7 @@ } // Now we know that this block has multiple preds and two succs. - if (!BlockIsSimpleEnoughToThreadThrough(BB)) + if (!BlockIsSimpleEnoughToThreadThrough(BB, AC)) return false; // Okay, this is a simple enough basic block. 
See if any phi values are @@ -3547,7 +3553,8 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, DomTreeUpdater *DTU, const DataLayout &DL, - const TargetTransformInfo &TTI) { + const TargetTransformInfo &TTI, + AssumptionCache *AC) { assert(PBI->isConditional() && BI->isConditional()); BasicBlock *BB = BI->getParent(); @@ -3569,7 +3576,7 @@ // Otherwise, if there are multiple predecessors, insert a PHI that merges // in the constant and simplify the block result. Subsequent passes of // simplifycfg will thread the block. - if (BlockIsSimpleEnoughToThreadThrough(BB)) { + if (BlockIsSimpleEnoughToThreadThrough(BB, AC)) { pred_iterator PB = pred_begin(BB), PE = pred_end(BB); PHINode *NewPN = PHINode::Create( Type::getInt1Ty(BB->getContext()), std::distance(PB, PE), @@ -6496,7 +6503,7 @@ for (BasicBlock *Pred : predecessors(BB)) if (BranchInst *PBI = dyn_cast<BranchInst>(Pred->getTerminator())) if (PBI != BI && PBI->isConditional()) - if (SimplifyCondBranchToCondBranch(PBI, BI, DTU, DL, TTI, Options.AC)) + return requestResimplify(); // Look for diamond patterns. 
diff --git a/llvm/test/Transforms/SimplifyCFG/unprofitable-pr.ll b/llvm/test/Transforms/SimplifyCFG/unprofitable-pr.ll --- a/llvm/test/Transforms/SimplifyCFG/unprofitable-pr.ll +++ b/llvm/test/Transforms/SimplifyCFG/unprofitable-pr.ll @@ -1,10 +1,11 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt -simplifycfg -simplifycfg-require-and-preserve-domtree=1 -simplifycfg-max-small-block-size=10 -S < %s | FileCheck %s -; RUN: opt -passes=simplify-cfg -simplifycfg-max-small-block-size=10 -S < %s | FileCheck %s +; RUN: opt -simplifycfg -simplifycfg-require-and-preserve-domtree=1 -simplifycfg-max-small-block-size=6 -S < %s | FileCheck %s +; RUN: opt -passes=simplify-cfg -simplifycfg-max-small-block-size=6 -S < %s | FileCheck %s target datalayout = "e-p:64:64-p5:32:32-A5" declare void @llvm.assume(i1) +declare i1 @llvm.type.test(i8*, metadata) nounwind readnone define void @test_01(i1 %c, i64* align 1 %ptr) local_unnamed_addr #0 { ; CHECK-LABEL: @test_01( @@ -165,3 +166,61 @@ store volatile i64 3, i64* %ptr, align 8 ret void } + +; Try the max block size for PRE again but with the bitcast/type test/assume +; sequence used for whole program devirt. 
+define void @test_04(i1 %c, i64* align 1 %ptr, [3 x i8*]* %vtable) local_unnamed_addr #0 { +; CHECK-LABEL: @test_04( +; CHECK-NEXT: br i1 [[C:%.*]], label [[TRUE2_CRITEDGE:%.*]], label [[FALSE1:%.*]] +; CHECK: false1: +; CHECK-NEXT: store volatile i64 1, i64* [[PTR:%.*]], align 4 +; CHECK-NEXT: [[VTABLE:%.*]] = bitcast [3 x i8*]* %vtable to i8* +; CHECK-NEXT: [[P:%.*]] = call i1 @llvm.type.test(i8* [[VTABLE]], metadata !"foo") +; CHECK-NEXT: tail call void @llvm.assume(i1 [[P]]) +; CHECK-NEXT: store volatile i64 0, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 3, i64* [[PTR]], align 8 +; CHECK-NEXT: ret void +; CHECK: true2.critedge: +; CHECK-NEXT: [[VTABLE:%.*]] = bitcast [3 x i8*]* %vtable to i8* +; CHECK-NEXT: [[P:%.*]] = call i1 @llvm.type.test(i8* [[VTABLE]], metadata !"foo") +; CHECK-NEXT: tail call void @llvm.assume(i1 [[P]]) +; CHECK-NEXT: store volatile i64 0, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 -1, i64* [[PTR]], align 8 +; CHECK-NEXT: store volatile i64 2, i64* [[PTR]], align 8 +; CHECK-NEXT: ret void +; + br i1 %c, label %true1, label %false1 + +true1: ; preds = %false1, %0 + %vtablei8 = bitcast [3 x i8*]* %vtable to i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"foo") + tail call void @llvm.assume(i1 %p) + store volatile i64 0, i64* %ptr, align 8 + store volatile i64 -1, i64* %ptr, align 8 + store volatile i64 -1, i64* %ptr, align 8 + store volatile i64 -1, i64* %ptr, align 8 + 
store volatile i64 -1, i64* %ptr, align 8 + store volatile i64 -1, i64* %ptr, align 8 + br i1 %c, label %true2, label %false2 + +false1: ; preds = %0 + store volatile i64 1, i64* %ptr, align 4 + br label %true1 + +true2: ; preds = %true1 + store volatile i64 2, i64* %ptr, align 8 + ret void + +false2: ; preds = %true1 + store volatile i64 3, i64* %ptr, align 8 + ret void +}