diff --git a/llvm/include/llvm/Transforms/Utils/CallbackEncapsulate.h b/llvm/include/llvm/Transforms/Utils/CallbackEncapsulate.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/CallbackEncapsulate.h
@@ -0,0 +1,116 @@
+//===- CallbackEncapsulate.h - Encapsulate abstract call sites --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper functions to encapsulate call sites with callback wrapper functions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_CALLBACK_ENCAPSULATE_H
+#define LLVM_TRANSFORMS_UTILS_CALLBACK_ENCAPSULATE_H
+
+#include "llvm/IR/CallSite.h"
+
+namespace llvm {
+
+/// This function encapsulates the \p Called and the \p Callee function (which
+/// can be the same) with new functions that are connected through a callback
+/// annotation. \p Called is the function directly called by the instruction
+/// underlying \p ACS; \p Callee is the function that \p ACS transitively calls.
+/// The callback annotation uses copies of the arguments and the original ones
+/// are still passed. We do this to allow later passes, e.g., argument
+/// promotion, to modify the passed arguments without changing the interface
+/// of \p Called and \p Callee. This can be good for two reasons:
+///
+/// (1) If \p Called is a declaration that has callback behavior and \p Callee
+/// is the callback callee, we could otherwise not modify the way arguments are
+/// passed between them.
+///
+/// (2) If \p Callee is passed a very large structure, we want to unpack it to
+/// facilitate later analysis, but we would otherwise lack the ability to pack
+/// it again to guarantee the same call performance.
+///
+/// The new abstract call site and the direct call site with the same callee
+/// are tied together through metadata, as shown in the example below.
+///
+/// Note that the encapsulation does not change the semantics of the code.
+/// While there are more functions and calls involved, there is no semantic
+/// change. Passes aware and unaware of the encoding can interpret and modify
+/// the code.
+///
+/// ------------------------------- Before ------------------------------------
+/// void foo() {
+///   (A) call Called(p0, p1);
+/// }
+///
+/// // The definition of Called might not be available. Called can be Callee
+/// // or contain a call to Callee, e.g., via callback metadata.
+/// void Called(arg0, arg1);
+///
+/// void Callee(arg2, arg3) {
+///   // Callee code
+/// }
+///
+///
+/// ------------------------------- After -------------------------------------
+///
+/// void foo() {
+///   // metadata !rpl_acs !{!1}
+///   (A) call before_wrapper(p0, p1, after_wrapper, p0, p1);
+/// }
+///
+/// __attribute__((callback(callee_w, arg2_w, arg3_w)))
+/// void before_wrapper(arg0, arg1, callee_w, arg2_w, arg3_w) {
+///   (B) call Called(arg0, arg1);
+/// }
+///
+/// // The definition of Called might not be available. Called can be Callee
+/// // or contain a call to Callee.
+/// void Called(arg0, arg1);
+///
+/// void Callee(arg2, arg3) {
+///   // metadata !rpl_cs !{!0}
+///   (C) call after_wrapper(arg2, arg3);
+/// }
+///
+/// void after_wrapper(arg2, arg3) {
+///   (D) // Callee code
+/// }
+///
+/// !0 = {!1}
+/// !1 = {!0}
+///
+/// In this encoding, the following (abstract) call edges exist:
+///   (1) (A) -> before_wrapper     [direct]
+///   (2) (A) -> after_wrapper      [transitive/callback]
+///   (3) (B) -> Called             [direct]
+///   (4) (C) -> after_wrapper      [direct]
+///
+/// The shown metadata is used to tie (2) and (4) together such that aware users
+/// can ignore (4) in favor of (2). If the metadata is corrupted or dropped, the
+/// connection cannot be made and (4) has to be taken into account. This is,
+/// for example, the case if (B) was inlined.
+///
+/// \returns The call that replaced the one described by \p ACS, (A) in the
+/// above example, or nullptr if the call site could not be encapsulated,
+/// e.g., because it is an invoke or a musttail call, a callee is unknown or
+/// variadic, or \p Callee is only a declaration.
+CallBase *encapsulateAbstractCallSite(AbstractCallSite ACS);
+
+/// Return true if \p CS is a direct call with a replacing abstract call site
+/// that should be used for inter-procedural reasoning instead.
+///
+/// This function should only be used by abstract call site aware
+/// inter-procedural passes. If the return value is true and the pass will
+/// eventually look at all direct and transitive call sites to derive
+/// information, it can ignore the direct call site \p CS as there is an
+/// abstract call site that encodes the same call.
+bool isDirectCallSiteReplacedByAbstractCallSite(ImmutableCallSite CS);
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_UTILS_CALLBACK_ENCAPSULATE_H
diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -6,6 +6,7 @@
   BuildLibCalls.cpp
   BypassSlowDivision.cpp
   CallPromotionUtils.cpp
+  CallbackEncapsulate.cpp
   CanonicalizeAliases.cpp
   CloneFunction.cpp
   CloneModule.cpp
diff --git a/llvm/lib/Transforms/Utils/CallbackEncapsulate.cpp b/llvm/lib/Transforms/Utils/CallbackEncapsulate.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/CallbackEncapsulate.cpp
@@ -0,0 +1,245 @@
+//===- CallbackEncapsulate.cpp -- Encapsulate callbacks in own functions --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helper functions to encapsulate call sites with callback wrapper functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CallbackEncapsulate.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/Support/Error.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "callback-encapsulate"
+
+STATISTIC(NumCallbacksEncapsulated, "Number of callbacks encapsulated");
+
+static constexpr char ReplicatedCallSiteString[] = "rpl_cs";
+static constexpr char ReplicatedAbstractCallSiteString[] = "rpl_acs";
+static constexpr char BeforeWrapperSuffix[] = ".before_wrapper";
+static constexpr char AfterWrapperSuffix[] = ".after_wrapper";
+
+/// Helper to extract the "rpl_cs" or "rpl_acs" metadata, selected by
+/// \p MDString, from the call site \p ICS. Returns the attached node and its
+/// single operand. Both are null if the metadata is missing or does not have
+/// exactly one operand; the second is null if the operand is not an MDNode.
+static std::pair<MDNode *, MDNode *> getMDNodeOperands(ImmutableCallSite ICS,
+                                                       StringRef MDString) {
+  std::pair<MDNode *, MDNode *> Ret;
+  const Instruction *I = ICS.getInstruction();
+  if (!I)
+    return Ret;
+
+  MDNode *MD = I->getMetadata(MDString);
+  if (!MD || MD->getNumOperands() != 1)
+    return Ret;
+
+  MDNode *OpMD = dyn_cast_or_null<MDNode>(MD->getOperand(0).get());
+  return {MD, OpMD};
+}
+
+bool llvm::isDirectCallSiteReplacedByAbstractCallSite(ImmutableCallSite CS) {
+  // Check if there is "rpl_cs" metadata on the call site and it matches a
+  // callback call site with "rpl_acs" metadata. This means this call site is
+  // equivalent to the abstract call site it matches.
+  std::pair<MDNode *, MDNode *> RplCSOpMD =
+      getMDNodeOperands(CS, ReplicatedCallSiteString);
+  if (!RplCSOpMD.first || !RplCSOpMD.second)
+    return false;
+
+  // Without a known callee there are no uses to search for the match.
+  const Function *Callee = CS.getCalledFunction();
+  if (!Callee)
+    return false;
+
+  for (const Use &U : Callee->uses()) {
+    AbstractCallSite ACS(&U);
+    if (!ACS || !ACS.isCallbackCall())
+      continue;
+
+    std::pair<MDNode *, MDNode *> RplACSOpMD =
+        getMDNodeOperands(ACS.getCallSite(), ReplicatedAbstractCallSiteString);
+    if (RplACSOpMD.first == RplCSOpMD.second &&
+        RplACSOpMD.second == RplCSOpMD.first)
+      return true;
+  }
+
+  return false;
+}
+
+CallBase *llvm::encapsulateAbstractCallSite(AbstractCallSite ACS) {
+  assert(ACS && "Expected valid abstract call site!");
+
+  // Only plain calls can be moved into the single block before wrapper: an
+  // invoke is a terminator and a musttail call requires matching prototypes.
+  // Both callees need to be known, and the body of the transitive callee is
+  // moved into the after wrapper, so it has to be a definition.
+  auto *CB = dyn_cast<CallInst>(ACS.getInstruction());
+  if (!CB || CB->isMustTailCall() || !CB->getCalledFunction() ||
+      !ACS.getCalledFunction() || ACS.getCalledFunction()->isDeclaration())
+    return nullptr;
+
+  Function &DirectCallee = *CB->getCalledFunction();
+  Function &TransitiveCallee = *ACS.getCalledFunction();
+
+  // If we have a direct call, the transitive callee is the same as the direct
+  // callee. However, for callback calls this might not be the case.
+  assert((ACS.isCallbackCall() || &DirectCallee == &TransitiveCallee) &&
+         "Broken invariant");
+
+  // We do not allow varargs for now.
+  if (DirectCallee.isVarArg() || TransitiveCallee.isVarArg())
+    return nullptr;
+
+  Module &M = *DirectCallee.getParent();
+  LLVMContext &Ctx = M.getContext();
+
+  FunctionType *AfterWrapperTy = TransitiveCallee.getFunctionType();
+  Function *AfterWrapper =
+      Function::Create(AfterWrapperTy, GlobalValue::InternalLinkage,
+                       TransitiveCallee.getName() + AfterWrapperSuffix, M);
+  AfterWrapper->setAttributes(TransitiveCallee.getAttributes());
+  // Keep the personality with the moved body, which may contain EH pads.
+  if (TransitiveCallee.hasPersonalityFn())
+    AfterWrapper->setPersonalityFn(TransitiveCallee.getPersonalityFn());
+  auto &AfterWrapperBlockList = AfterWrapper->getBasicBlockList();
+  auto WrapperAI = AfterWrapper->arg_begin();
+  for (Argument &Arg : TransitiveCallee.args()) {
+    Argument *WrapperArg = &*(WrapperAI++);
+    Arg.replaceAllUsesWith(WrapperArg);
+    WrapperArg->setName(Arg.getName());
+  }
+  AfterWrapperBlockList.splice(AfterWrapperBlockList.begin(),
+                               TransitiveCallee.getBasicBlockList());
+  BasicBlock *AfterEntryBB =
+      BasicBlock::Create(Ctx, "entry", &TransitiveCallee);
+
+  // The after wrapper has the same interface as the transitive callee. The
+  // transitive call will just redirect to the after wrapper, thus simply pass
+  // all arguments along.
+  SmallVector<Value *, 8> Args;
+  Args.reserve(TransitiveCallee.arg_size());
+  for (Argument &Arg : TransitiveCallee.args())
+    Args.push_back(&Arg);
+
+  // Void values cannot be named, so only non-void calls get a name.
+  bool TransitiveIsVoid = TransitiveCallee.getReturnType()->isVoidTy();
+  CallInst *AfterWrapperCB =
+      CallInst::Create(AfterWrapperTy, AfterWrapper, Args, "", AfterEntryBB);
+  if (!TransitiveIsVoid)
+    AfterWrapperCB->setName(TransitiveCallee.getName() + ".acs");
+  ReturnInst::Create(Ctx, TransitiveIsVoid ? nullptr : AfterWrapperCB,
+                     AfterEntryBB);
+
+  // Prepare the arguments for the call that is also an abstract call site.
+  // Every argument is passed at most twice and the callee of the abstract call
+  // site is passed in the middle.
+  Args.clear();
+  Args.reserve(CB->getNumArgOperands() * 2 + 1);
+  Args.append(CB->arg_begin(), CB->arg_end());
+
+  int CBCalleeIdx = Args.size();
+  Args.push_back(AfterWrapper);
+
+  SmallVector<int, 8> PayloadIndices;
+  AttributeList CalleeFnAttrs = TransitiveCallee.getAttributes();
+  AttributeList ExtCalledAttrs = DirectCallee.getAttributes();
+
+  // Collect the arguments that go into the abstract call. These are all
+  // arguments if the call site ACS was direct or the subset that the abstract
+  // call site ACS actually used. Given that we might skip arguments we need to
+  // track the payload indices for the callback encoding as well. Finally, we
+  // keep the attributes of the original arguments we duplicate around.
+  for (unsigned u = 0, e = TransitiveCallee.arg_size(); u < e; u++) {
+    int OpIdx = ACS.getCallArgOperandNo(u);
+    if (OpIdx < 0)
+      continue;
+    PayloadIndices.push_back(Args.size());
+    AttributeSet CalleeFnParamAttrs = CalleeFnAttrs.getParamAttributes(u);
+    ExtCalledAttrs = ExtCalledAttrs.addParamAttributes(
+        Ctx, Args.size(), AttrBuilder(CalleeFnParamAttrs));
+    Args.push_back(CB->getOperand(OpIdx));
+  }
+
+  SmallVector<Type *, 8> ArgTypes;
+  for (Value *V : Args)
+    ArgTypes.push_back(V->getType());
+
+  FunctionType *BeforeWrapperTy =
+      FunctionType::get(DirectCallee.getReturnType(), ArgTypes, false);
+  Function *BeforeWrapper =
+      Function::Create(BeforeWrapperTy, GlobalValue::InternalLinkage,
+                       DirectCallee.getName() + BeforeWrapperSuffix, M);
+  BeforeWrapper->setAttributes(ExtCalledAttrs);
+
+  MDBuilder MDB(Ctx);
+  SmallVector<Metadata *, 4> CBEncodings;
+  CBEncodings.push_back(
+      MDB.createCallbackEncoding(CBCalleeIdx, PayloadIndices,
+                                 /* VarArgsArePassed */ false));
+
+  // If the direct callee already has callback metadata we copy it to the before
+  // wrapper which has the same behavior and argument prefix.
+  MDNode *ExistingCBMD = DirectCallee.getMetadata(LLVMContext::MD_callback);
+  if (ExistingCBMD)
+    CBEncodings.append(ExistingCBMD->op_begin(), ExistingCBMD->op_end());
+  BeforeWrapper->addMetadata(LLVMContext::MD_callback,
+                             *MDNode::get(Ctx, CBEncodings));
+
+  auto *BeforeWrapperCB = CallInst::Create(BeforeWrapper->getFunctionType(),
+                                           BeforeWrapper, Args, "", CB);
+  if (!BeforeWrapperCB->getType()->isVoidTy())
+    BeforeWrapperCB->setName(TransitiveCallee.getName() + ".cs");
+  // Make all users of the old call use the new one; the old call itself is
+  // moved into the before wrapper below.
+  CB->replaceAllUsesWith(BeforeWrapperCB);
+
+  // Create and attach the encoding metadata to the two call sites (one
+  // abstract, one direct) of the called wrapper function. The nodes reference
+  // each other and are distinct so that setting the operand below cannot
+  // modify other, structurally equal (uniqued) nodes in the context.
+  MDNode *BeforeWrapperCBMD = MDNode::getDistinct(Ctx, {nullptr});
+  BeforeWrapperCB->setMetadata(ReplicatedAbstractCallSiteString,
+                               BeforeWrapperCBMD);
+
+  MDNode *AfterWrapperCBMD = MDNode::getDistinct(Ctx, {BeforeWrapperCBMD});
+  AfterWrapperCB->setMetadata(ReplicatedCallSiteString, AfterWrapperCBMD);
+  BeforeWrapperCBMD->replaceOperandWith(0, AfterWrapperCBMD);
+
+  BasicBlock *BeforeEntryBB = BasicBlock::Create(Ctx, "entry", BeforeWrapper);
+  ReturnInst *RI = ReturnInst::Create(
+      Ctx, DirectCallee.getReturnType()->isVoidTy() ? nullptr : CB,
+      BeforeEntryBB);
+  // Reuse the old call in the new wrapper.
+  CB->moveBefore(RI);
+
+  // Rewire the reused call to the leading before wrapper arguments, which
+  // mirror the direct callee arguments, and name them accordingly.
+  auto CalledAI = DirectCallee.arg_begin();
+  auto BeforeWrapperAI = BeforeWrapper->arg_begin();
+  for (unsigned u = 0, e = CB->getNumArgOperands(); u < e; u++) {
+    BeforeWrapperAI->setName((CalledAI++)->getName());
+    CB->setArgOperand(u, &*(BeforeWrapperAI++));
+  }
+
+  // BeforeWrapperAI now points to the after wrapper (callee) argument. The
+  // payload arguments that follow mirror the transitive callee arguments the
+  // abstract call site passes, so name them after those.
+  for (Argument &Arg : TransitiveCallee.args())
+    if (ACS.getCallArgOperandNo(Arg.getArgNo()) >= 0)
+      (++BeforeWrapperAI)->setName(Arg.getName());
+
+  ++NumCallbacksEncapsulated;
+  return BeforeWrapperCB;
+}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_llvm_unittest(UtilsTests
   ASanStackFrameLayoutTest.cpp
   BasicBlockUtilsTest.cpp
+  CallbackEncapsulateTest.cpp
   CloningTest.cpp
   CodeExtractorTest.cpp
   CodeMoverUtilsTest.cpp
diff --git a/llvm/unittests/Transforms/Utils/CallbackEncapsulateTest.cpp b/llvm/unittests/Transforms/Utils/CallbackEncapsulateTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/CallbackEncapsulateTest.cpp
@@ -0,0 +1,502 @@
+//===- CallbackEncapsulateTest.cpp - Unit tests for callback encapsulate --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CallbackEncapsulate.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/DIBuilder.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, StringRef IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("CallbackEncapsulate", errs());
+  return Mod;
+}
+
+static void verifyModulePtr(std::unique_ptr<Module> &M) {
+  ASSERT_NE(M, nullptr);
+  ASSERT_FALSE(verifyModule(*M, &errs()));
+}
+
+static void verifyNoReplCalls(std::unique_ptr<Module> &M) {
+  for (Function &F : *M)
+    for (Instruction &I : instructions(&F))
+      if (auto *CI = dyn_cast<CallInst>(&I))
+        EXPECT_FALSE(
+            isDirectCallSiteReplacedByAbstractCallSite(ImmutableCallSite(CI)));
+}
+
+static CallBase *verifyEncapsulateCallSites(std::unique_ptr<Module> &M,
+                                            AbstractCallSite ACS) {
+  CallBase *CB = encapsulateAbstractCallSite(ACS);
+  EXPECT_TRUE(CB);
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  return CB;
+}
+
+static void verifyEncapsulatedState(std::unique_ptr<Module> &M,
+                                    Function *Caller, Use *CalledUse,
+                                    Use *CalleeUse) {
+  AbstractCallSite CalledACS(CalledUse);
+  ASSERT_TRUE(CalledACS);
+  Function *Called = CalledACS.getCalledFunction();
+  ASSERT_NE(Called, nullptr);
+
+  AbstractCallSite CalleeACS(CalleeUse);
+  ASSERT_TRUE(CalleeACS);
+  Function *Callee = CalleeACS.getCalledFunction();
+  ASSERT_NE(Callee, nullptr);
+
+  Function *BeforeWrapper =
+      M->getFunction((Called->getName() + ".before_wrapper").str());
+  ASSERT_NE(BeforeWrapper, nullptr);
+  EXPECT_TRUE(BeforeWrapper->hasLocalLinkage());
+  BeforeWrapper->setName(BeforeWrapper->getName() + ".seen");
+
+  Function *AfterWrapper =
+      M->getFunction((Callee->getName() + ".after_wrapper").str());
+  ASSERT_NE(AfterWrapper, nullptr);
+  EXPECT_TRUE(AfterWrapper->hasLocalLinkage());
+  AfterWrapper->setName(AfterWrapper->getName() + ".seen");
+
+  EXPECT_EQ(AfterWrapper->getFunctionType(), Callee->getFunctionType());
+
+  // Verify we see the right number of function users (this is to some degree
+  // specific to the test structure).
+  EXPECT_EQ(Caller->getNumUses(), 0U);
+  EXPECT_EQ(AfterWrapper->getNumUses(), BeforeWrapper->getNumUses() + 1);
+
+  AbstractCallSite BeforeACS(&*BeforeWrapper->use_begin());
+  bool FoundDirectAfterACS = false, FoundCallbackAfterACS = false;
+  ASSERT_GE(AfterWrapper->getNumUses(), 1U);
+  AbstractCallSite AfterDirectACS(&*AfterWrapper->use_begin()),
+      AfterCallbackACS(&*AfterWrapper->use_begin());
+  for (const Use &U : AfterWrapper->uses()) {
+    AbstractCallSite ACS(&U);
+    ASSERT_TRUE(ACS);
+    if (ACS.isDirectCall()) {
+      ASSERT_FALSE(FoundDirectAfterACS);
+      AfterDirectACS = ACS;
+      FoundDirectAfterACS = true;
+    } else if (!FoundCallbackAfterACS) {
+      AfterCallbackACS = ACS;
+      FoundCallbackAfterACS = true;
+    }
+  }
+  ASSERT_TRUE(FoundDirectAfterACS);
+  ASSERT_TRUE(FoundCallbackAfterACS);
+  ASSERT_TRUE(AfterDirectACS);
+  ASSERT_TRUE(AfterCallbackACS);
+
+  // Verify the call sites are as expected.
+  ASSERT_TRUE(CalleeACS);
+  if (Called == Callee) {
+    EXPECT_TRUE(CalleeACS.isDirectCall());
+  } else {
+    EXPECT_TRUE(CalleeACS.isCallbackCall());
+  }
+  ASSERT_TRUE(BeforeACS);
+  EXPECT_TRUE(BeforeACS.isDirectCall());
+  EXPECT_TRUE(AfterDirectACS.isDirectCall());
+  EXPECT_TRUE(AfterCallbackACS.isCallbackCall());
+
+  EXPECT_TRUE(
+      isDirectCallSiteReplacedByAbstractCallSite(AfterDirectACS.getCallSite()));
+  EXPECT_FALSE(isDirectCallSiteReplacedByAbstractCallSite(
+      AfterCallbackACS.getCallSite()));
+
+  EXPECT_FALSE(
+      isDirectCallSiteReplacedByAbstractCallSite(CalleeACS.getCallSite()));
+  EXPECT_FALSE(
+      isDirectCallSiteReplacedByAbstractCallSite(BeforeACS.getCallSite()));
+
+  ASSERT_TRUE(isa<CallInst>(CalleeACS.getInstruction()));
+  ASSERT_TRUE(isa<CallInst>(BeforeACS.getInstruction()));
+  auto *BeforeWrapperCI = cast<CallInst>(BeforeACS.getInstruction());
+  ASSERT_TRUE(isa<CallInst>(AfterDirectACS.getInstruction()));
+  auto *AfterWrapperCI = cast<CallInst>(AfterDirectACS.getInstruction());
+
+  auto CalleeAttrs = Callee->getAttributes();
+  auto BeforeAttrs = BeforeWrapper->getAttributes();
+  auto AfterAttrs = AfterWrapper->getAttributes();
+  EXPECT_EQ(CalleeAttrs.getRetAttributes(), BeforeAttrs.getRetAttributes());
+  EXPECT_EQ(CalleeAttrs.getRetAttributes(), AfterAttrs.getRetAttributes());
+
+  // Verify the arguments and their mapping is as expected.
+  ASSERT_EQ(Callee->arg_size(), AfterWrapper->arg_size());
+  if (Called == Callee) {
+    ASSERT_EQ(Callee->arg_size() * 2 + 1, BeforeWrapper->arg_size());
+  } else {
+    ASSERT_LE(Callee->arg_size() + 1, BeforeWrapper->arg_size());
+  }
+
+  for (unsigned ArgNo = 0; ArgNo < CalleeACS.getNumArgOperands(); ++ArgNo) {
+    int OpIdx = CalleeACS.getCallArgOperandNo(ArgNo);
+    if (OpIdx < 0)
+      continue;
+
+    auto CalleeAI = Callee->arg_begin() + ArgNo;
+    auto AfterAI = AfterWrapper->arg_begin() + ArgNo;
+    auto BeforeAI = BeforeWrapper->arg_begin() + OpIdx;
+    EXPECT_EQ(CalleeAI->getType(), AfterAI->getType());
+    EXPECT_EQ(CalleeAI->getType(), BeforeAI->getType());
+
+    EXPECT_EQ(CalleeAttrs.getAttributes(AttributeList::FirstArgIndex + ArgNo),
+              BeforeAttrs.getAttributes(AttributeList::FirstArgIndex + OpIdx));
+    EXPECT_EQ(CalleeAttrs.getAttributes(AttributeList::FirstArgIndex + ArgNo),
+              AfterAttrs.getAttributes(AttributeList::FirstArgIndex + ArgNo));
+
+    EXPECT_EQ(CalleeAI->getNumUses(), 1U);
+    EXPECT_EQ(BeforeAI->getNumUses(), 1U);
+    EXPECT_EQ(CalleeAI->user_back(), AfterWrapperCI);
+    ASSERT_EQ(BeforeAI->user_back(), CalledACS.getInstruction());
+  }
+
+  // Second part of the before wrapper arguments is dependent on the call(back)
+  // but we know it starts with the after wrapper.
+  auto BeforeAI = BeforeWrapper->arg_begin() + Called->arg_size();
+  EXPECT_EQ(BeforeWrapperCI->getArgOperand(BeforeAI->getArgNo()), AfterWrapper);
+  ++BeforeAI;
+
+  for (unsigned ArgNo = 0; ArgNo < CalleeACS.getNumArgOperands(); ++ArgNo) {
+    int OpIdx = CalleeACS.getCallArgOperandNo(ArgNo);
+    if (OpIdx < 0)
+      continue;
+
+    unsigned BeforeArgNo = BeforeAI->getArgNo();
+
+    auto CalleeAI = Callee->arg_begin() + ArgNo;
+    EXPECT_EQ(CalleeAI->getType(), BeforeAI->getType());
+
+    EXPECT_EQ(
+        CalleeAttrs.getAttributes(AttributeList::FirstArgIndex + ArgNo),
+        BeforeAttrs.getAttributes(AttributeList::FirstArgIndex + BeforeArgNo));
+
+    EXPECT_EQ(BeforeAI->getNumUses(), 0U);
+
+    ++BeforeAI;
+  }
+}
+
+class CallbackEncapsulateTest : public ::testing::Test {
+protected:
+  LLVMContext C;
+};
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateDirectCall0) {
+
+  StringRef ModuleAssembly = R"(
+define noalias double* @callee() {
+entry:
+  ret double* null;
+}
+
+define double* @caller(i32 %unused) {
+entry:
+  %call = call double* @callee()
+  ret double* %call
+}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Callee = M->getFunction("callee");
+  Function *Called = Callee;
+  Use *CalleeUse = &*Callee->use_begin();
+  Use *CalledUse = &*Called->use_begin();
+  verifyEncapsulateCallSites(M, AbstractCallSite(CalleeUse));
+  verifyEncapsulatedState(M, Caller, CalledUse, CalleeUse);
+}
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateDirectCall1) {
+
+  StringRef ModuleAssembly = R"(
+define double @callee(i32 %i0, double %d0, i16 signext %s0, i16 signext %s1, double %d1, i32 %i1) {
+entry:
+  %conv = sext i32 %i0 to i64
+  call void @use(i64 %conv)
+  %conv1 = fptosi double %d0 to i64
+  call void @use(i64 %conv1)
+  %conv2 = sext i16 %s0 to i64
+  call void @use(i64 %conv2)
+  %conv3 = sext i32 %i1 to i64
+  call void @use(i64 %conv3)
+  %conv4 = fptosi double %d1 to i64
+  call void @use(i64 %conv4)
+  %conv5 = sext i16 %s1 to i64
+  call void @use(i64 %conv5)
+  %add = fadd double %d0, %d1
+  ret double %add
+}
+
+declare void @use(i64)
+
+define double @caller(i32 %i, double %d) {
+entry:
+  %conv = sitofp i32 %i to double
+  %mul = fmul double %conv, %d
+  %conv1 = fptosi double %mul to i16
+  %call = call double @callee(i32 %i, double %d, i16 signext %conv1, i16 signext %conv1, double %d, i32 %i)
+  ret double %call
+}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Callee = M->getFunction("callee");
+  Function *Called = Callee;
+  Use *CalleeUse = &*Callee->use_begin();
+  Use *CalledUse = &*Called->use_begin();
+  verifyEncapsulateCallSites(M, AbstractCallSite(CalleeUse));
+  verifyEncapsulatedState(M, Caller, CalledUse, CalleeUse);
+}
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateDirectCalls0) {
+
+  StringRef ModuleAssembly = R"(
+define double* @callee() {
+entry:
+  ret double* null;
+}
+
+define double* @caller(i32 %unused) {
+entry:
+  %call0 = call double* @callee()
+  %call1 = call double* @callee()
+  ret double* %call0
+}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Callee = M->getFunction("callee");
+  // For direct calls Called is Callee, so the callee use of each call site is
+  // also the use of the called function that has to be verified.
+  Use *CalleeUse0 = &*Callee->use_begin();
+  Use *CalleeUse1 = &*(++Callee->use_begin());
+  verifyEncapsulateCallSites(M, AbstractCallSite(CalleeUse0));
+  verifyEncapsulatedState(M, Caller, CalleeUse0, CalleeUse0);
+  verifyEncapsulateCallSites(M, AbstractCallSite(CalleeUse1));
+  verifyEncapsulatedState(M, Caller, CalleeUse1, CalleeUse1);
+}
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateTransitiveCall0) {
+
+  StringRef ModuleAssembly = R"(
+%union.pthread_attr_t = type { i64, [48 x i8] }
+
+define dso_local i32 @caller(i8* %arg) {
+entry:
+  %thread = alloca i64, align 8
+  store i8 0, i8* %arg
+  %call = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee, i8* %arg)
+  ret i32 0
+}
+
+declare !callback !0 dso_local i32 @pthread_create(i64*, %union.pthread_attr_t*, i8* (i8*)*, i8*)
+
+define internal i8* @callee(i8* %arg) {
+entry:
+  %l = load i8, i8* %arg
+  %add = add i8 %l, 1
+  store i8 %add, i8* %arg
+  ret i8* %arg
+}
+
+!1 = !{i64 2, i64 3, i1 false}
+!0 = !{!1}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Called = M->getFunction("pthread_create");
+  Function *Callee = M->getFunction("callee");
+  Use *CalledUse = &*Called->use_begin();
+  verifyEncapsulateCallSites(M, AbstractCallSite(&*Callee->use_begin()));
+  verifyEncapsulatedState(M, Caller, CalledUse, &*Callee->use_begin());
+}
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateTransitiveCalls0) {
+
+  StringRef ModuleAssembly = R"(
+%union.pthread_attr_t = type { i64, [48 x i8] }
+
+define dso_local i32 @caller(i8* %arg) {
+entry:
+  %thread = alloca i64, align 8
+  store i8 0, i8* %arg
+  %call0 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee, i8* %arg)
+  %call1 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee, i8* %arg)
+  %call2 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee, i8* %arg)
+  ret i32 0
+}
+
+declare !callback !0 dso_local i32 @pthread_create(i64*, %union.pthread_attr_t*, i8* (i8*)*, i8*)
+
+define internal i8* @callee(i8* %arg) {
+entry:
+  %l = load i8, i8* %arg
+  %add = add i8 %l, 1
+  store i8 %add, i8* %arg
+  ret i8* %arg
+}
+
+!1 = !{i64 2, i64 3, i1 false}
+!0 = !{!1}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Called = M->getFunction("pthread_create");
+  Function *Callee = M->getFunction("callee");
+
+  EXPECT_EQ(Callee->getNumUses(), 3U);
+  EXPECT_EQ(Called->getNumUses(), 3U);
+  // The three pthread_create calls follow the alloca and the store at the
+  // beginning of the entry block.
+  Instruction *EntryIt = Caller->getEntryBlock().getFirstNonPHI();
+  CallBase *Call0 = cast<CallBase>(EntryIt->getNextNode()->getNextNode());
+  CallBase *Call1 = cast<CallBase>(Call0->getNextNode());
+  CallBase *Call2 = cast<CallBase>(Call1->getNextNode());
+  Use *CalledUse0 = &Call0->getCalledOperandUse();
+  Call0 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call0->getOperandUse(2)));
+  ASSERT_TRUE(Call0);
+  verifyEncapsulatedState(M, Caller, CalledUse0, &Call0->getOperandUse(2));
+  Use *CalledUse1 = &Call1->getCalledOperandUse();
+  Call1 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call1->getOperandUse(2)));
+  ASSERT_TRUE(Call1);
+  verifyEncapsulatedState(M, Caller, CalledUse1, &Call1->getOperandUse(2));
+  Use *CalledUse2 = &Call2->getCalledOperandUse();
+  Call2 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call2->getOperandUse(2)));
+  ASSERT_TRUE(Call2);
+  verifyEncapsulatedState(M, Caller, CalledUse2, &Call2->getOperandUse(2));
+  EXPECT_EQ(Callee->getNumUses(), 3U);
+}
+
+TEST_F(CallbackEncapsulateTest, CallbackEncapsulateTransitiveCalls1) {
+
+  StringRef ModuleAssembly = R"(
+%union.pthread_attr_t = type { i64, [48 x i8] }
+
+define dso_local i32 @caller(i8* %arg) {
+entry:
+  %thread = alloca i64, align 8
+  store i8 0, i8* %arg
+  %call0 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee0, i8* %arg)
+  %call1 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee1, i8* %arg)
+  %call2 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @callee2, i8* %arg)
+  ret i32 0
+}
+
+declare !callback !0 dso_local i32 @pthread_create(i64*, %union.pthread_attr_t*, i8* (i8*)*, i8*)
+
+define internal i8* @callee0(i8* %arg) {
+entry:
+  %l = load i8, i8* %arg
+  %add = add i8 %l, 1
+  store i8 %add, i8* %arg
+  ret i8* %arg
+}
+define internal i8* @callee1(i8* %arg) {
+entry:
+  %l = load i8, i8* %arg
+  %add = add i8 %l, 1
+  store i8 %add, i8* %arg
+  ret i8* %arg
+}
+define internal i8* @callee2(i8* %arg) {
+entry:
+  %l = load i8, i8* %arg
+  %add = add i8 %l, 1
+  store i8 %add, i8* %arg
+  ret i8* %arg
+}
+
+!1 = !{i64 2, i64 3, i1 false}
+!0 = !{!1}
+  )";
+
+  std::unique_ptr<Module> M = parseIR(C, ModuleAssembly);
+  ASSERT_NO_FATAL_FAILURE(verifyModulePtr(M));
+  verifyNoReplCalls(M);
+
+  Function *Caller = M->getFunction("caller");
+  Function *Called = M->getFunction("pthread_create");
+  Function *Callee0 = M->getFunction("callee0");
+  Function *Callee1 = M->getFunction("callee1");
+  Function *Callee2 = M->getFunction("callee2");
+
+  EXPECT_EQ(Callee0->getNumUses(), 1U);
+  EXPECT_EQ(Callee1->getNumUses(), 1U);
+  EXPECT_EQ(Callee2->getNumUses(), 1U);
+  EXPECT_EQ(Called->getNumUses(), 3U);
+  // The three pthread_create calls follow the alloca and the store at the
+  // beginning of the entry block.
+  Instruction *EntryIt = Caller->getEntryBlock().getFirstNonPHI();
+  CallBase *Call0 = cast<CallBase>(EntryIt->getNextNode()->getNextNode());
+  CallBase *Call1 = cast<CallBase>(Call0->getNextNode());
+  CallBase *Call2 = cast<CallBase>(Call1->getNextNode());
+  Use *CalledUse0 = &Call0->getCalledOperandUse();
+  Call0 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call0->getOperandUse(2)));
+  ASSERT_TRUE(Call0);
+  verifyEncapsulatedState(M, Caller, CalledUse0, &Call0->getOperandUse(2));
+  Use *CalledUse1 = &Call1->getCalledOperandUse();
+  Call1 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call1->getOperandUse(2)));
+  ASSERT_TRUE(Call1);
+  verifyEncapsulatedState(M, Caller, CalledUse1, &Call1->getOperandUse(2));
+  Use *CalledUse2 = &Call2->getCalledOperandUse();
+  Call2 =
+      verifyEncapsulateCallSites(M, AbstractCallSite(&Call2->getOperandUse(2)));
+  ASSERT_TRUE(Call2);
+  verifyEncapsulatedState(M, Caller, CalledUse2, &Call2->getOperandUse(2));
+  EXPECT_EQ(Callee0->getNumUses(), 1U);
+  EXPECT_EQ(Callee1->getNumUses(), 1U);
+  EXPECT_EQ(Callee2->getNumUses(), 1U);
+}
+
+} // namespace