Index: llvm/trunk/include/llvm/Analysis/BitSetUtils.h =================================================================== --- llvm/trunk/include/llvm/Analysis/BitSetUtils.h +++ llvm/trunk/include/llvm/Analysis/BitSetUtils.h @@ -0,0 +1,38 @@ +//===- BitSetUtils.h - Utilities related to pointer bitsets ------*- C++ -*-==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains functions that make it easier to manipulate bitsets for +// devirtualization. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_BITSETUTILS_H +#define LLVM_ANALYSIS_BITSETUTILS_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CallSite.h" + +namespace llvm { + +// A call site that could be devirtualized. +struct DevirtCallSite { + // The offset from the address point to the virtual function. + uint64_t Offset; + // The call site itself. + CallSite CS; +}; + +// Given a call to the intrinsic @llvm.bitset.test, find all devirtualizable
// call sites based on the call and return them in DevirtCalls. +void findDevirtualizableCalls(SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<CallInst *> &Assumes, + CallInst *CI); +} + +#endif Index: llvm/trunk/lib/Analysis/BitSetUtils.cpp =================================================================== --- llvm/trunk/lib/Analysis/BitSetUtils.cpp +++ llvm/trunk/lib/Analysis/BitSetUtils.cpp @@ -0,0 +1,82 @@ +//===- BitSetUtils.cpp - Utilities related to pointer bitsets -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file contains functions that make it easier to manipulate bitsets for +// devirtualization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BitSetUtils.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +// Search for virtual calls that call FPtr and add them to DevirtCalls. +static void +findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, + Value *FPtr, uint64_t Offset) { + for (const Use &U : FPtr->uses()) { + Value *User = U.getUser(); + if (isa<BitCastInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, User, Offset); + } else if (auto CI = dyn_cast<CallInst>(User)) { + DevirtCalls.push_back({Offset, CI}); + } else if (auto II = dyn_cast<InvokeInst>(User)) { + DevirtCalls.push_back({Offset, II}); + } + } +} + +// Search for virtual calls that load from VPtr and add them to DevirtCalls. +static void +findLoadCallsAtConstantOffset(Module *M, + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + Value *VPtr, uint64_t Offset) { + for (const Use &U : VPtr->uses()) { + Value *User = U.getUser(); + if (isa<BitCastInst>(User)) { + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset); + } else if (isa<LoadInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, User, Offset); + } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { + // Take into account the GEP offset. 
+ if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { + SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); + uint64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType( + GEP->getSourceElementType(), Indices); + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset); + } + } + } +} + +void llvm::findDevirtualizableCalls( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) { + assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::bitset_test); + + Module *M = CI->getParent()->getParent()->getParent(); + + // Find llvm.assume intrinsics for this llvm.bitset.test call. + for (const Use &CIU : CI->uses()) { + auto AssumeCI = dyn_cast<CallInst>(CIU.getUser()); + if (AssumeCI) { + Function *F = AssumeCI->getCalledFunction(); + if (F && F->getIntrinsicID() == Intrinsic::assume) + Assumes.push_back(AssumeCI); + } + } + + // If we found any, search for virtual calls based on %p and add them to + // DevirtCalls. + if (!Assumes.empty()) + findLoadCallsAtConstantOffset(M, DevirtCalls, + CI->getArgOperand(0)->stripPointerCasts(), 0); +} Index: llvm/trunk/lib/Analysis/CMakeLists.txt =================================================================== --- llvm/trunk/lib/Analysis/CMakeLists.txt +++ llvm/trunk/lib/Analysis/CMakeLists.txt @@ -5,6 +5,7 @@ Analysis.cpp AssumptionCache.cpp BasicAliasAnalysis.cpp + BitSetUtils.cpp BlockFrequencyInfo.cpp BlockFrequencyInfoImpl.cpp BranchProbabilityInfo.cpp Index: llvm/trunk/lib/Transforms/IPO/WholeProgramDevirt.cpp =================================================================== --- llvm/trunk/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ llvm/trunk/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" +#include "llvm/Analysis/BitSetUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -231,10 +232,6 @@ : M(M), 
Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())) {} - void findLoadCallsAtConstantOffset(Metadata *BitSet, Value *Ptr, - uint64_t Offset, Value *VTable); - void findCallsAtConstantOffset(Metadata *BitSet, Value *Ptr, uint64_t Offset, - Value *VTable); void buildBitSets(std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets); @@ -283,43 +280,6 @@ return new WholeProgramDevirt; } -// Search for virtual calls that call FPtr and add them to CallSlots. -void DevirtModule::findCallsAtConstantOffset(Metadata *BitSet, Value *FPtr, - uint64_t Offset, Value *VTable) { - for (const Use &U : FPtr->uses()) { - Value *User = U.getUser(); - if (isa<BitCastInst>(User)) { - findCallsAtConstantOffset(BitSet, User, Offset, VTable); - } else if (auto CI = dyn_cast<CallInst>(User)) { - CallSlots[{BitSet, Offset}].push_back({VTable, CI}); - } else if (auto II = dyn_cast<InvokeInst>(User)) { - CallSlots[{BitSet, Offset}].push_back({VTable, II}); - } - } -} - -// Search for virtual calls that load from VPtr and add them to CallSlots. -void DevirtModule::findLoadCallsAtConstantOffset(Metadata *BitSet, Value *VPtr, - uint64_t Offset, - Value *VTable) { - for (const Use &U : VPtr->uses()) { - Value *User = U.getUser(); - if (isa<BitCastInst>(User)) { - findLoadCallsAtConstantOffset(BitSet, User, Offset, VTable); - } else if (isa<LoadInst>(User)) { - findCallsAtConstantOffset(BitSet, User, Offset, VTable); - } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { - // Take into account the GEP offset. - if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { - SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); - uint64_t GEPOffset = M.getDataLayout().getIndexedOffsetInType( - GEP->getSourceElementType(), Indices); - findLoadCallsAtConstantOffset(BitSet, User, Offset + GEPOffset, VTable); - } - } - } -} - void DevirtModule::buildBitSets( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) { @@ -674,22 +634,23 @@ if (!CI) continue; - // Find llvm.assume intrinsics for this llvm.bitset.test call. 
+ // Search for virtual calls based on %p and add them to DevirtCalls. + SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; - for (const Use &CIU : CI->uses()) { - auto AssumeCI = dyn_cast<CallInst>(CIU.getUser()); - if (AssumeCI && AssumeCI->getCalledValue() == AssumeFunc) - Assumes.push_back(AssumeCI); - } + findDevirtualizableCalls(DevirtCalls, Assumes, CI); - // If we found any, search for virtual calls based on %p and add them to - // CallSlots. + // If we found any, add them to CallSlots. Only do this if we haven't seen + // the vtable pointer before, as it may have been CSE'd with pointers from + // other call sites, and we don't want to process call sites multiple times. if (!Assumes.empty()) { Metadata *BitSet = cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); - if (SeenPtrs.insert(Ptr).second) - findLoadCallsAtConstantOffset(BitSet, Ptr, 0, CI->getArgOperand(0)); + if (SeenPtrs.insert(Ptr).second) { + for (DevirtCallSite Call : DevirtCalls) + CallSlots[{BitSet, Call.Offset}].push_back( + {CI->getArgOperand(0), Call.CS}); + } } // We no longer need the assumes or the bitset test.