Index: include/llvm/InitializePasses.h
===================================================================
--- include/llvm/InitializePasses.h
+++ include/llvm/InitializePasses.h
@@ -122,6 +122,7 @@
 void initializeDependenceAnalysisWrapperPassPass(PassRegistry&);
 void initializeDetectDeadLanesPass(PassRegistry&);
+void initializeDevirtPrivateMethodsLegacyPassPass(PassRegistry&);
 void initializeDivRemPairsLegacyPassPass(PassRegistry&);
 void initializeDomOnlyPrinterPass(PassRegistry&);
 void initializeDomOnlyViewerPass(PassRegistry&);
 void initializeDomPrinterPass(PassRegistry&);
Index: include/llvm/Transforms/IPO.h
===================================================================
--- include/llvm/Transforms/IPO.h
+++ include/llvm/Transforms/IPO.h
@@ -300,6 +300,10 @@
 ModulePass *createWriteThinLTOBitcodePass(raw_ostream &Str,
                                           raw_ostream *ThinLinkOS = nullptr);
 
+/// Create a legacy module pass that rewrites objc_msgSend calls on 'self' that
+/// resolve to private Objective-C methods into direct calls.
+ModulePass *createDevirtPrivateMethodsLegacyPass();
+
 } // End llvm namespace
 
 #endif
Index: lib/Transforms/IPO/DevirtualizePrivateMethods.cpp
===================================================================
--- /dev/null
+++ lib/Transforms/IPO/DevirtualizePrivateMethods.cpp
@@ -0,0 +1,368 @@
+//===- DevirtualizePrivateMethods.cpp - Devirtualize Private Methods ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Module pass that turns objc_msgSend calls which message 'self' with a
+// selector implemented by a private method of this module into direct calls.
+//
+//   1. Parse the module's method list into a method-name -> implementation
+//      table.
+//   2. Collect the private methods and the block invoke functions.
+//   3. Blacklist every private method that is messaged from a block.
+//   4. For each call site inside an Objective-C method: record it when it
+//      messages a private, non-blacklisted method on 'self'.
+//   5. Second pass over the recorded call sites: retarget each one to the
+//      method implementation.
+//
+// TODO: Devirt all or none for a method by first collecting the candidates.
+// TODO: Disable devirt of init methods.
+// TODO: Add null check after devirt
+// TODO: Bail out on message passed to id type.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Utils/ObjCMetadataUtils.h"
+
+#include <map>
+#include <set>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "devirtualize-private-methods"
+
+namespace llvm {
+/// Shared implementation behind the legacy and new pass manager wrappers.
+class DevirtPrivateMethods {
+public:
+  explicit DevirtPrivateMethods(Module &M) : M(M) {}
+
+  /// Runs the transformation on the module.
+  /// \returns true if at least one call site was rewritten.
+  bool run();
+
+private:
+  void getMethodsFromMethodList(GlobalVariable *MList);
+  Function *getMethod(CallSite CS, Instruction *I) const;
+  bool isMessageToSelf(CallSite CS, Function &F) const;
+  void getPrivateMethods();
+  void blacklistInvocationsFromBlocks();
+  bool getDevirtCandidates();
+  bool devirtualize();
+
+  Module &M;
+  // All the private methods in this module.
+  std::set<Function *> PrivateMethods;
+  // We analyze blocks to blacklist private methods called by them.
+  SmallVector<Function *, 8> BlockFunctions;
+  // Blacklisted Functions.
+  std::set<Function *> BlacklistedFunctions;
+  // MethodList inside module M.
+  std::set<GlobalVariable *> MList;
+  // Maps from method names to function definitions. A null value marks a
+  // name with conflicting definitions; such a name never resolves.
+  std::map<GlobalVariable *, Function *> Vtable;
+  // Call instructions to retarget, mapped to the resolved callee.
+  std::map<Instruction *, Function *> Candidates;
+};
+
+/// New pass manager wrapper around DevirtPrivateMethods.
+struct DevirtPrivateMethodsPass : PassInfoMixin<DevirtPrivateMethodsPass> {
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
+    DevirtPrivateMethods D(M);
+    // Retargeting calls modifies the IR, so only claim that everything is
+    // preserved when nothing was rewritten.
+    return D.run() ? PreservedAnalyses::none() : PreservedAnalyses::all();
+  }
+};
+} // namespace llvm
+
+// Block invoke functions are recognised by "block_invoke" in their name.
+static bool isBlock(const Function *F) {
+  return F->getName().find("block_invoke") != StringRef::npos;
+}
+
+// Skip C-Functions.
+static bool skipFunction(const Function *F) {
+  // Each Objective-C method is prefixed with: "\01"
+  // TODO: Find a better way to differentiate from other functions.
+  return !F->getName().startswith("\01");
+}
+
+// Fills Vtable from the entries of a method list global. Anything that does
+// not have the expected constant shape is skipped, which only costs
+// devirtualization opportunities.
+void DevirtPrivateMethods::getMethodsFromMethodList(GlobalVariable *MList) {
+  if (!MList || !MList->hasInitializer())
+    return;
+  auto *ListStruct = dyn_cast<ConstantStruct>(MList->getInitializer());
+  if (!ListStruct || ListStruct->getNumOperands() < 3)
+    return;
+  // The third field of the list holds the array of method entries.
+  auto *Methods = dyn_cast<ConstantArray>(ListStruct->getOperand(2));
+  if (!Methods)
+    return;
+
+  // Each element of the ConstantArray should be a ConstantStruct.
+  for (const Use &Op : Methods->operands()) {
+    auto *Entry = dyn_cast<ConstantStruct>(Op.get());
+    if (!Entry || Entry->getNumOperands() < 3)
+      continue;
+
+    // Third field: the implementation, bitcast to i8*.
+    auto *MethPtr = dyn_cast<ConstantExpr>(Entry->getOperand(2));
+    if (!MethPtr) // Protocol method lists have a null method pointer.
+      continue;
+    if (MethPtr->getOpcode() != Instruction::BitCast)
+      continue;
+    auto *Impl =
+        dyn_cast<Function>(MethPtr->getOperand(0)->stripPointerCasts());
+    if (!Impl)
+      continue;
+
+    // First field: the method name, a GEP into the name string global.
+    auto *MethName = dyn_cast<ConstantExpr>(Entry->getOperand(0));
+    if (!MethName || MethName->getOpcode() != Instruction::GetElementPtr)
+      continue;
+    auto *NameGV = dyn_cast<GlobalVariable>(MethName->getOperand(0));
+    if (!NameGV)
+      continue;
+
+    // Two different implementations under one name cannot be told apart
+    // here; store null for that name so that it never resolves.
+    auto Slot = Vtable.insert(std::make_pair(NameGV, Impl));
+    if (!Slot.second && Slot.first->second != Impl)
+      Slot.first->second = nullptr;
+  }
+}
+
+// 'self' is the first argument to objc_msgsend and to the caller as well.
+// define internal void @"\01-[MyObject privMethod:]"(%0* %self, i8* %_cmd) {
+// entry:
+//   %0 = load i8*, i8** @OBJC_SELECTOR_REFERENCES_
+//   %1 = bitcast %0* %self to i8*
+//   call @objc_msgSend to void (i8*, i8*)*)(i8* %1, i8* %0)
+//   ret void
+// }
+bool DevirtPrivateMethods::isMessageToSelf(CallSite CS, Function &F) const {
+  // Without a receiver or without a 'self' parameter there is nothing to
+  // compare.
+  if (CS.arg_empty() || F.arg_empty())
+    return false;
+  /// objc_msgSend having StRet as first argument. TODO: devirt these as well.
+  auto *Receiver = dyn_cast<BitCastInst>(CS.getArgument(0));
+  if (!Receiver)
+    return false;
+  // self is the first argument of method F.
+  return Receiver->getOperand(0) == &*F.arg_begin();
+}
+
+// Returns the function called by parsing the callsite,
+// selector-references table and method-list table.
+//   %0 = load i8*, i8** bitcast (
+//            %struct._class_t** @"OBJC_CLASSLIST_REFERENCES_$_" to i8**)
+//   %1 = load i8*, i8** @OBJC_SELECTOR_REFERENCES_
+//   %call = call %0* bitcast (i8* (i8*, i8*, ...)*
+//            @objc_msgSend to %0* (i8*, i8*, i32)*)(i8* %0, i8* %1, i32 5)
+// Returns null when the call is not a resolvable objc_msgSend.
+Function *DevirtPrivateMethods::getMethod(CallSite CS, Instruction *I) const {
+  (void)I; // The call site already carries everything needed.
+  if (CS.isInlineAsm() || CS.isIndirectCall())
+    return nullptr;
+  // Remove bitcast from the function pointer. A direct call can still
+  // target a constant that is not a function.
+  Value *Called = CS.getCalledValue();
+  auto *F = Called ? dyn_cast<Function>(Called->stripPointerCasts()) : nullptr;
+  // TODO: Handle objc_msgSendSuper(2)
+  if (!F || !F->getName().equals("objc_msgSend"))
+    return nullptr;
+  /// _objc_method struct
+  /// struct _objc_method {
+  ///   SEL _cmd;
+  ///   char *method_type;
+  ///   char *_imp;
+  /// }
+  /// An instance of _objc_method privMethod
+  /// %struct._objc_method {
+  ///   i8* getelementptr inbounds ([21 x i8], [21 x i8]* @OBJC_METH_VAR_NAME_,
+  ///                               i32 0, i32 0),
+  ///   i8* getelementptr inbounds ([8 x i8], [8 x i8]* @OBJC_METH_VAR_TYPE_.4,
+  ///                               i32 0, i32 0),
+  ///   i8* bitcast (void (%0*, i8*)* @"\01-[MyObject privMethod]" to i8*)
+  /// }
+  // The selector is the second argument.
+  if (CS.arg_size() < 2)
+    return nullptr;
+  /// objc_msgSend having StRet as first argument. TODO: devirt these as well.
+  auto *SelRefLoad = dyn_cast<LoadInst>(CS.getArgument(1));
+  if (!SelRefLoad)
+    return nullptr;
+  // The SelRef can be (conditionally) assigned to a global variable, don't
+  // follow through just bail out for now.
+  auto *SelRef = dyn_cast<GlobalVariable>(SelRefLoad->getPointerOperand());
+  if (!SelRef)
+    return nullptr;
+  /// @OBJC_METH_VAR_NAME_ = private unnamed_addr constant [21 x i8]
+  ///                          c"privMethod\00"
+  /// @OBJC_SELECTOR_REFERENCES_ = private externally_initialized global
+  ///   i8* getelementptr inbounds ([21 x i8], [21 x i8]* @OBJC_METH_VAR_NAME_,
+  ///                               i32 0, i32 0)
+  GlobalVariable *Method = findMethodNameForSelRef(SelRef);
+  auto It = Vtable.find(Method);
+  return It != Vtable.end() ? It->second : nullptr;
+}
+
+// A method can be called directly only if it is not variadic and this module
+// holds its exact definition.
+static bool couldDevirt(const Function *F) {
+  return !F->isVarArg() && F->hasExactDefinition();
+}
+
+// Partitions the defined functions of the module into block invoke functions
+// and devirtualizable private methods.
+void DevirtPrivateMethods::getPrivateMethods() {
+  for (Function &F : M.functions()) {
+    if (F.isDeclaration())
+      continue;
+    if (isBlock(&F)) {
+      BlockFunctions.push_back(&F);
+      continue;
+    }
+    if (F.hasFnAttribute(Attribute::ObjCPrivateMethod) && couldDevirt(&F))
+      PrivateMethods.insert(&F);
+  }
+}
+
+// Records every private method that a block invoke function messages.
+void DevirtPrivateMethods::blacklistInvocationsFromBlocks() {
+  for (Function *F : BlockFunctions) {
+    for (BasicBlock &BB : *F) {
+      for (Instruction &I : BB) {
+        auto CS = CallSite(&I);
+        if (!CS)
+          continue;
+        // Get method name from objc
+        Function *Callee = getMethod(CS, &I);
+        if (!Callee || !PrivateMethods.count(Callee))
+          continue;
+        BlacklistedFunctions.insert(Callee);
+        LLVM_DEBUG(dbgs() << "\nBlacklisting: " << Callee->getName());
+      }
+    }
+  }
+}
+
+// Retargets every recorded call site from objc_msgSend to the resolved
+// method. The call keeps its arguments and its function type; only the callee
+// is replaced, bitcast to the type of the original callee.
+// Returns true if any call site was rewritten.
+bool DevirtPrivateMethods::devirtualize() {
+  for (const auto &P : Candidates) {
+    Instruction *Call = P.first;
+    Function *Callee = P.second;
+    auto CS = CallSite(Call);
+    LLVM_DEBUG(dbgs() << "\nPrivate devirt from " << *Call);
+    CS.setCalledFunction(
+        ConstantExpr::getBitCast(Callee, CS.getCalledValue()->getType()));
+    Callee->removeFnAttr(Attribute::AlwaysInline);
+    Callee->addFnAttr(Attribute::NoInline);
+    LLVM_DEBUG(dbgs() << "\nPrivate devirt to: " << Callee->getName());
+  }
+  return !Candidates.empty();
+}
+
+// From objc_msgSend get the second parameter (selector)
+// go through the selector references to get the method reference
+// from the method list get the callee
+// Returns true if at least one call site was recorded.
+bool DevirtPrivateMethods::getDevirtCandidates() {
+  for (Function &F : M) {
+    if (skipFunction(&F))
+      continue;
+    for (BasicBlock &BB : F) {
+      for (Instruction &I : BB) {
+        auto CS = CallSite(&I);
+        if (!CS)
+          continue;
+        Function *Callee = getMethod(CS, &I);
+        if (!Callee)
+          continue;
+        // If the called function is one of the private methods
+        if (!PrivateMethods.count(Callee))
+          continue;
+        // Methods messaged from a block were blacklisted and stay dynamic.
+        if (BlacklistedFunctions.count(Callee))
+          continue;
+        if (!isMessageToSelf(CS, F))
+          continue;
+        Candidates[&I] = Callee;
+      }
+    }
+  }
+  return !Candidates.empty();
+}
+
+// Drives the analysis and the rewrite; returns true if the module changed.
+bool DevirtPrivateMethods::run() {
+  getAllMethodLists(M, MList);
+  if (MList.size() != 1) {
+    // TODO: Handle method lists of different kind, this can help
+    // running this pass in lto-mode but may unnecessarily complicate
+    // the analysis.
+    // class methods list and instance methods lists can co-exist.
+    return false;
+  }
+  getMethodsFromMethodList(*MList.begin());
+  getPrivateMethods();
+  blacklistInvocationsFromBlocks();
+  if (!getDevirtCandidates())
+    return false;
+  // Report the change so that the pass manager knows the module was modified.
+  return devirtualize();
+}
+
+namespace {
+/// Legacy pass manager wrapper around DevirtPrivateMethods.
+struct DevirtPrivateMethodsLegacyPass : public ModulePass {
+  static char ID; // Pass identification, replacement for typeid
+  DevirtPrivateMethodsLegacyPass() : ModulePass(ID) {
+    initializeDevirtPrivateMethodsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {}
+
+  bool runOnModule(Module &M) override {
+    if (skipModule(M))
+      return false;
+    DevirtPrivateMethods D(M);
+    return D.run();
+  }
+};
+} // namespace
+
+char DevirtPrivateMethodsLegacyPass::ID = 0;
+INITIALIZE_PASS(DevirtPrivateMethodsLegacyPass, "devirtualize-private-methods",
+                "Devirtualize Private Methods in ObjC", false, false)
+
+ModulePass *llvm::createDevirtPrivateMethodsLegacyPass() {
+  return new DevirtPrivateMethodsLegacyPass();
+}