WholeProgramDevirt: decide virtual constant propagation eligibility by
analyzing the target function's body instead of its function attributes.

Summary of the changes in this patch:
  * DevirtModule gains an AARGetter member: a function_ref callback mapping a
    Function to its AAResults (the new-pass-manager lambda below is declared
    as (Function &F) -> AAResults &). It is threaded through the DevirtModule
    constructor and through DevirtModule::runForTesting.
  * In the per-target eligibility loop, the check
    `!Target.Fn->doesNotAccessMemory()` is replaced by
    `computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != MAK_ReadNone`.
    The source comment added next to it explains why inspecting this copy of
    the body (rather than attributes that must hold for any copy) is sound.
  * Legacy pass manager: WholeProgramDevirt::runOnModule now passes
    LegacyAARGetter(*this). A getAnalysisUsage override is added, and the pass
    registration switches to INITIALIZE_PASS_BEGIN/END with dependencies on
    AssumptionCacheTracker and TargetLibraryInfoWrapperPass.
  * New pass manager: WholeProgramDevirtPass::run obtains the function
    analysis manager from the module analysis manager and builds the
    AARGetter lambda from it.
  * Test vcp-accesses-memory.ll:
      - adds a second RUN line using -passes=wholeprogramdevirt;
      - each vtable now has two slots;
      - slot 0 (@vf1a/@vf2a) stores %arg to @sink, so those targets are not
        readnone and fail the eligibility check (the CHECK lines for @call1
        sit in context elided from this hunk);
      - slot 1 (@vf1b/@vf2b) only returns %arg, and @call2, which calls
        through slot 1 with constant 1, is checked for `ret i32 1`.

NOTE(review): in this copy, text inside angle brackets appears to have been
stripped, and the patch lines are collapsed. The patch will therefore not
apply or compile as-is. Affected spots include:
  * `function_ref AARGetter`: the lambda signature suggests
    function_ref<AAResults &(Function &)>;
  * both `AU.addRequired();` calls: presumably the two analyses named in the
    INITIALIZE_PASS_DEPENDENCY lines;
  * `AM.getResult(M)` and `FAM.getResult(F)`;
  * `std::map NumUnsafeUsesForTypeTest`;
  * the two bare `#include` context lines.
Restore these from the original patch before applying.

NOTE(review): @call1/@call2 view the vtable as [1 x i8*], while @vt1/@vt2 are
now [2 x i8*]. @call2's GEP index 1 is therefore past the nominal array type.
The GEP is not `inbounds`, so this is presumably still valid IR. Viewing the
vtable as [2 x i8*] in the callers would make the test easier to follow;
confirm before changing.

Index: llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp =================================================================== --- llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -35,6 +35,8 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -63,6 +65,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/Evaluator.h" #include #include @@ -326,6 +329,7 @@ struct DevirtModule { Module &M; + function_ref AARGetter; PassSummaryAction Action; ModuleSummaryIndex *Summary; @@ -349,8 +353,9 @@ // true. std::map NumUnsafeUsesForTypeTest; - DevirtModule(Module &M, PassSummaryAction Action, ModuleSummaryIndex *Summary) - : M(M), Action(Action), Summary(Summary), + DevirtModule(Module &M, function_ref AARGetter, + PassSummaryAction Action, ModuleSummaryIndex *Summary) + : M(M), AARGetter(AARGetter), Action(Action), Summary(Summary), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), @@ -401,7 +406,8 @@ // Lower the module using the action and summary passed as command line // arguments. For testing purposes only. 
- static bool runForTesting(Module &M); + static bool runForTesting(Module &M, + function_ref AARGetter); }; struct WholeProgramDevirt : public ModulePass { @@ -425,15 +431,24 @@ if (skipModule(M)) return false; if (UseCommandLine) - return DevirtModule::runForTesting(M); - return DevirtModule(M, Action, Summary).run(); + return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); + return DevirtModule(M, LegacyAARGetter(*this), Action, Summary).run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); } }; } // end anonymous namespace -INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", - "Whole program devirtualization", false, false) +INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) char WholeProgramDevirt::ID = 0; ModulePass *llvm::createWholeProgramDevirtPass(PassSummaryAction Action, @@ -442,13 +457,18 @@ } PreservedAnalyses WholeProgramDevirtPass::run(Module &M, - ModuleAnalysisManager &) { - if (!DevirtModule(M, PassSummaryAction::None, nullptr).run()) + ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult(M).getManager(); + auto AARGetter = [&](Function &F) -> AAResults & { + return FAM.getResult(F); + }; + if (!DevirtModule(M, AARGetter, PassSummaryAction::None, nullptr).run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } -bool DevirtModule::runForTesting(Module &M) { +bool DevirtModule::runForTesting( + Module &M, function_ref AARGetter) { ModuleSummaryIndex Summary; // Handle the command-line summary arguments. 
This code is for testing @@ -464,7 +484,7 @@ ExitOnErr(errorCodeToError(In.error())); } - bool Changed = DevirtModule(M, ClSummaryAction, &Summary).run(); + bool Changed = DevirtModule(M, AARGetter, ClSummaryAction, &Summary).run(); if (!ClWriteSummary.empty()) { ExitOnError ExitOnErr( @@ -754,8 +774,17 @@ // Make sure that each function is defined, does not access memory, takes at // least one argument, does not use its first argument (which we assume is // 'this'), and has the same return type. + // + // Note that we test whether this copy of the function is readnone, rather + // than testing function attributes, which must hold for any copy of the + // function, even a less optimized version substituted at link time. This is + // sound because the virtual constant propagation optimizations effectively + // inline all implementations of the virtual function into each call site, + // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { - if (Target.Fn->isDeclaration() || !Target.Fn->doesNotAccessMemory() || + if (Target.Fn->isDeclaration() || + computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != + MAK_ReadNone || Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || Target.Fn->getReturnType() != RetType) return false; Index: llvm/test/Transforms/WholeProgramDevirt/vcp-accesses-memory.ll =================================================================== --- llvm/test/Transforms/WholeProgramDevirt/vcp-accesses-memory.ll +++ llvm/test/Transforms/WholeProgramDevirt/vcp-accesses-memory.ll @@ -1,21 +1,34 @@ ; RUN: opt -S -wholeprogramdevirt %s | FileCheck %s +; RUN: opt -S -passes=wholeprogramdevirt %s | FileCheck %s target datalayout = "e-p:64:64" target triple = "x86_64-unknown-linux-gnu" -@vt1 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf1 to i8*)], !type !0 -@vt2 = constant [1 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2 to i8*)], !type !0 +@vt1 = constant [2 x 
i8*] [i8* bitcast (i32 (i8*, i32)* @vf1a to i8*), i8* bitcast (i32 (i8*, i32)* @vf1b to i8*)], !type !0 +@vt2 = constant [2 x i8*] [i8* bitcast (i32 (i8*, i32)* @vf2a to i8*), i8* bitcast (i32 (i8*, i32)* @vf2b to i8*)], !type !0 -define i32 @vf1(i8* %this, i32 %arg) { +@sink = external global i32 + +define i32 @vf1a(i8* %this, i32 %arg) { + store i32 %arg, i32* @sink ret i32 %arg } -define i32 @vf2(i8* %this, i32 %arg) { +define i32 @vf2a(i8* %this, i32 %arg) { + store i32 %arg, i32* @sink ret i32 %arg } -; CHECK: define i32 @call -define i32 @call(i8* %obj) { +define i32 @vf1b(i8* %this, i32 %arg) { + ret i32 %arg +} + +define i32 @vf2b(i8* %this, i32 %arg) { + ret i32 %arg +} + +; CHECK: define i32 @call1 +define i32 @call1(i8* %obj) { %vtableptr = bitcast i8* %obj to [1 x i8*]** %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr %vtablei8 = bitcast [1 x i8*]* %vtable to i8* @@ -29,6 +42,21 @@ ret i32 %result } +; CHECK: define i32 @call2 +define i32 @call2(i8* %obj) { + %vtableptr = bitcast i8* %obj to [1 x i8*]** + %vtable = load [1 x i8*]*, [1 x i8*]** %vtableptr + %vtablei8 = bitcast [1 x i8*]* %vtable to i8* + %p = call i1 @llvm.type.test(i8* %vtablei8, metadata !"typeid") + call void @llvm.assume(i1 %p) + %fptrptr = getelementptr [1 x i8*], [1 x i8*]* %vtable, i32 0, i32 1 + %fptr = load i8*, i8** %fptrptr + %fptr_casted = bitcast i8* %fptr to i32 (i8*, i32)* + %result = call i32 %fptr_casted(i8* %obj, i32 1) + ; CHECK: ret i32 1 + ret i32 %result +} + declare i1 @llvm.type.test(i8*, metadata) declare void @llvm.assume(i1)