Index: include/llvm/IR/CallSite.h
===================================================================
--- include/llvm/IR/CallSite.h
+++ include/llvm/IR/CallSite.h
@@ -686,6 +686,163 @@
   User::op_iterator getCallee() const;
 };
 
+/// AbstractCallSite
+///
+/// An abstract call site is a wrapper that allows treating direct,
+/// indirect, and callback calls the same. If an abstract call site
+/// represents a direct or indirect call site it behaves like a stripped
+/// down version of a normal call site object. The abstract call site can
+/// also represent a callback call, i.e., the fact that the initially
+/// called function (=broker) may invoke a third one (=callback callee).
+/// In this case, the abstract call site hides the middle man, hence the
+/// broker function. The result is a representation of the callback call,
+/// inside the broker, but in the context of the original instruction.
+///
+/// There are up to three functions involved when we talk about callback call
+/// sites. The caller (1), which invokes the broker function. The broker
+/// function (2), that may or may not invoke the callee. And finally the callee
+/// (3), which is the target of the callback call.
+///
+/// The abstract call site will handle the mapping from parameters to arguments
+/// depending on the semantics of the broker function. However, it is important
+/// to note that the mapping is often partial. Thus, some arguments of the
+/// call/invoke instruction are mapped to parameters of the callee while others
+/// are not.
+class AbstractCallSite {
+public:
+
+  /// The encoding of a callback with regards to the underlying instruction.
+  struct CallbackInfo {
+
+    /// For direct/indirect calls the parameter encoding is empty. If it is not,
+    /// the abstract call site represents a callback. In that case, the first
+    /// element of the encoding vector represents which argument of the call
+    /// site CS is the callback callee. The remaining elements map parameters
+    /// (identified by their position) to the arguments that will be passed
+    /// through (also identified by position but in the call site instruction).
+    /// Every -1 entry represents an unknown value that is passed to the callee.
+    using ParameterEncodingTy = SmallVector<int, 0>;
+    ParameterEncodingTy ParameterEncoding;
+
+  };
+
+private:
+
+  /// The underlying call site:
+  ///   caller -> callee,             if this is a direct or indirect call site
+  ///   caller -> broker function,    if this is a callback call site
+  CallSite CS;
+
+  /// The encoding of a callback with regards to the underlying instruction.
+  CallbackInfo CI;
+
+public:
+  /// Sole constructor for abstract call sites (ACS).
+  AbstractCallSite(const Use *U);
+
+  /// Conversion operator to conveniently check for a valid/initialized ACS.
+  explicit operator bool() const { return (bool)CS; }
+
+  /// Return the underlying instruction.
+  Instruction *getInstruction() const { return CS.getInstruction(); }
+
+  /// Return the call site abstraction for the underlying instruction.
+  CallSite getCallSite() const { return CS; }
+
+  /// Return true if this ACS represents a direct call.
+  bool isDirectCall() const {
+    return !isCallbackCall() && !CS.isIndirectCall();
+  }
+
+  /// Return true if this ACS represents an indirect call.
+  bool isIndirectCall() const {
+    return !isCallbackCall() && CS.isIndirectCall();
+  }
+
+  /// Return true if this ACS represents a callback call.
+  bool isCallbackCall() const {
+    // For a callback call site the callee is ALWAYS stored first in the
+    // parameter encoding. Thus, a non-empty encoding indicates a callback.
+    return !CI.ParameterEncoding.empty();
+  }
+
+  /// Return true if @p UI is the use that defines the callee of this ACS.
+  bool isCallee(Value::const_user_iterator UI) const {
+    return isCallee(&UI.getUse());
+  }
+
+  /// Return true if @p U is the use that defines the callee of this ACS.
+  bool isCallee(const Use *U) const {
+    if (!isCallbackCall())
+      return CS.isCallee(U);
+
+    // Only an argument operand of the broker call can be the callback callee.
+    return U->getUser() == CS.getInstruction() && CS.isArgOperand(U) &&
+           (int)CS.getArgumentNo(U) == CI.ParameterEncoding[0];
+  }
+
+  /// Return the number of parameters of the callee.
+  unsigned getNumArgOperands() const {
+    if (!isCallbackCall())
+      return CS.getNumArgOperands();
+    // Subtract 1 for the callee encoding.
+    return CI.ParameterEncoding.size() - 1;
+  }
+
+  /// Return the operand index of the underlying instruction associated with @p
+  /// Arg.
+  int getCallArgOperandNo(Argument &Arg) const {
+    return getCallArgOperandNo(Arg.getArgNo());
+  }
+
+  /// Return the operand index of the underlying instruction associated with
+  /// the function parameter number @p ArgNo, or -1 if there is none.
+  int getCallArgOperandNo(unsigned ArgNo) const {
+    if (!isCallbackCall())
+      return ArgNo;
+    // Add 1 for the callee encoding.
+    return CI.ParameterEncoding[ArgNo + 1];
+  }
+
+  /// Return the operand of the underlying instruction associated with @p Arg.
+  Value *getCallArgOperand(Argument &Arg) const {
+    return getCallArgOperand(Arg.getArgNo());
+  }
+
+  /// Return the operand of the underlying instruction associated with the
+  /// function parameter number @p ArgNo, or null if there is none.
+  Value *getCallArgOperand(unsigned ArgNo) const {
+    if (!isCallbackCall())
+      return CS.getArgOperand(ArgNo);
+    // Add 1 for the callee encoding.
+    return CI.ParameterEncoding[ArgNo + 1] >= 0
+               ? CS.getArgOperand(CI.ParameterEncoding[ArgNo + 1])
+               : nullptr;
+  }
+
+  /// Return the operand index of the underlying instruction associated with the
+  /// callee of this ACS. Only valid for callback calls!
+  int getCallArgOperandNoForCallee() const {
+    assert(isCallbackCall());
+    assert(CI.ParameterEncoding.size() && CI.ParameterEncoding[0] >= 0);
+    return CI.ParameterEncoding[0];
+  }
+
+  /// Return the called value; for a callback this is the callback callee.
+  Value *getCalledValue() const {
+    if (!isCallbackCall())
+      return CS.getCalledValue();
+    return CS.getArgOperand(getCallArgOperandNoForCallee());
+  }
+
+  /// Return the function being called, looking through pointer casts, or
+  /// null if the called value is not a function (e.g., an indirect call).
+  Function *getCalledFunction() const {
+    Value *V = getCalledValue();
+    return V ? dyn_cast<Function>(V->stripPointerCasts()) : nullptr;
+  }
+};
+
 template <> struct DenseMapInfo<CallSite> {
   using BaseInfo = DenseMapInfo<decltype(CallSite::I)>;
 
Index: lib/IR/AbstractCallSite.cpp
===================================================================
--- /dev/null
+++ lib/IR/AbstractCallSite.cpp
@@ -0,0 +1,153 @@
+//===-- AbstractCallSite.cpp - Implementation of abstract call sites ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements abstract call sites which unify the interface for
+// direct, indirect, and callback call sites.
+//
+// For more information see:
+// https://llvm.org/devmtg/2018-10/talk-abstracts.html#talk20
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "abstract-call-sites"
+
+STATISTIC(NumValidAbstractCallSites,
+          "Number of valid abstract call sites created");
+STATISTIC(NumInvalidAbstractCallSitesNoBroker,
+          "Number of invalid abstract call sites created (no broker)");
+STATISTIC(NumInvalidAbstractCallSitesUnknownUse,
+          "Number of invalid abstract call sites created (unknown use)");
+STATISTIC(NumInvalidAbstractCallSitesUnknownBroker,
+          "Number of invalid abstract call sites created (unknown broker)");
+STATISTIC(NumInvalidAbstractCallSitesNoCallback,
+          "Number of invalid abstract call sites created (no callback)");
+STATISTIC(NumCallbackCallSites, "Number of callback call sites created");
+
+/// Create an abstract call site from a use. The result is invalid unless @p U
+/// is the callee operand of a call site, or the callback callee operand of a
+/// call to a known broker function.
+AbstractCallSite::AbstractCallSite(const Use *U) : CS(U->getUser()) {
+
+  // First handle unknown users.
+  if (!CS) {
+
+    // If the use is actually in a constant cast expression which itself
+    // has only one use, we look through the constant cast expression.
+    // This happens by updating the use @p U to the use of the constant
+    // cast expression and afterwards re-initializing CS accordingly.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U->getUser()))
+      if (CE->hasOneUse() && CE->isCast()) {
+        U = &*CE->use_begin();
+        CS = CallSite(U->getUser());
+      }
+
+    if (!CS) {
+      // Bookkeeping.
+      NumInvalidAbstractCallSitesUnknownUse++;
+      return;
+    }
+  }
+
+  // Then handle direct or indirect calls. Thus, if U is the callee of the
+  // call site CS it is not a callback and we are done.
+  if (CS.isCallee(U)) {
+    // Bookkeeping.
+    NumValidAbstractCallSites++;
+    return;
+  }
+
+  // If we cannot identify the broker function we cannot create a callback and
+  // invalidate the abstract call site.
+  Function *Callee = CS.getCalledFunction();
+  if (!Callee) {
+    CS = static_cast<Instruction *>(nullptr);
+
+    // Bookkeeping.
+    NumInvalidAbstractCallSitesNoBroker++;
+    return;
+  }
+
+  // A collection of known broker functions.
+  enum KnownBrokerFunctions {
+    KBF_UNKNOWN,
+    KBF_PTHREAD_CREATE, /// < pthread_create
+    KBF_KMPC_FORK_CALL, /// < __kmpc_fork_call
+  };
+
+  // Use a hard-coded name matching for now. This should be extracted and
+  // extended with an encoding, e.g., in metadata.
+  KnownBrokerFunctions BrokerFn =
+      StringSwitch<KnownBrokerFunctions>(Callee->getName())
+          .Case("__kmpc_fork_call", KBF_KMPC_FORK_CALL)
+          .Case("pthread_create", KBF_PTHREAD_CREATE)
+          .Default(KBF_UNKNOWN);
+
+  // First check if we did indeed find a known broker function or not.
+  if (BrokerFn == KBF_UNKNOWN) {
+    // If we cannot create a callback for this use we invalidate the abstract
+    // call site.
+    CS = static_cast<Instruction *>(nullptr);
+
+    // Bookkeeping.
+    NumInvalidAbstractCallSitesUnknownBroker++;
+    return;
+  }
+
+  switch (BrokerFn) {
+  case KBF_PTHREAD_CREATE:
+    LLVM_DEBUG(dbgs() << "Found a callback through pthread_create!\n");
+
+    // This encodes that the third argument (=2) to the call is the callee of
+    // the callback, followed by a single pass-through argument.
+    CI.ParameterEncoding.append({2, 3});
+
+    break;
+
+  case KBF_KMPC_FORK_CALL:
+    LLVM_DEBUG(dbgs() << "Found a callback through __kmpc_fork_call (#pragma "
+                         "omp parallel [for])!\n");
+
+    // This encodes that the third argument (=2) to the call is the callee of
+    // the callback, followed by two parameters without a corresponding
+    // argument, and finally followed by all variadic arguments that are just
+    // passed through to the callee.
+    CI.ParameterEncoding.append({2, -1, -1});
+    for (unsigned u = 3; u < CS.getNumArgOperands(); u++)
+      CI.ParameterEncoding.push_back(u);
+
+    break;
+
+  default:
+    llvm_unreachable("Broker function was not handled!");
+  }
+
+  // The use has to be the callback callee operand of the broker call. If the
+  // function is passed in any other position, e.g., as the pthread_create
+  // payload, it is not invoked through the broker and there is no callback.
+  if (!CS.isArgOperand(U) ||
+      (int)CS.getArgumentNo(U) != CI.ParameterEncoding[0]) {
+    CI.ParameterEncoding.clear();
+    CS = static_cast<Instruction *>(nullptr);
+
+    // Bookkeeping.
+    NumInvalidAbstractCallSitesNoCallback++;
+    return;
+  }
+
+  // Bookkeeping.
+  NumCallbackCallSites++;
+}
+
Index: lib/IR/CMakeLists.txt
===================================================================
--- lib/IR/CMakeLists.txt
+++ lib/IR/CMakeLists.txt
@@ -3,6 +3,7 @@
 add_public_tablegen_target(AttributeCompatFuncTableGen)
 
 add_llvm_library(LLVMCore
+  AbstractCallSite.cpp
   AsmWriter.cpp
   Attributes.cpp
   AutoUpgrade.cpp
Index: test/Transforms/IPConstantProp/pthreads-payload.ll
===================================================================
--- /dev/null
+++ test/Transforms/IPConstantProp/pthreads-payload.ll
@@ -0,0 +1,34 @@
+; RUN: opt -ipconstprop -S < %s | FileCheck %s
+;
+; @payload is passed as the data argument of pthread_create, not as the start
+; routine, so there is no callback call site for it and its argument has to be
+; left untouched. The constant data argument is still propagated into @start.
+; @payload is defined first so that it is visited while the bitcast of its
+; address is used only by the pthread_create call.
+;
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+%union.pthread_attr_t = type { i64, [48 x i8] }
+
+define dso_local i32 @main() {
+entry:
+  %thread = alloca i64, align 8
+  %call = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @start, i8* bitcast (i8* (i8*)* @payload to i8*))
+  ret i32 0
+}
+
+declare dso_local i32 @pthread_create(i64*, %union.pthread_attr_t*, i8* (i8*)*, i8*)
+
+define internal i8* @payload(i8* %arg) {
+entry:
+; CHECK-LABEL: define internal i8* @payload(
+; CHECK: ret i8* %arg
+  ret i8* %arg
+}
+
+define internal i8* @start(i8* %arg) {
+entry:
+; CHECK-LABEL: define internal i8* @start(
+; CHECK: ret i8* bitcast (i8* (i8*)* @payload to i8*)
+  ret i8* %arg
+}
Index: lib/Transforms/IPO/IPConstantPropagation.cpp
===================================================================
--- lib/Transforms/IPO/IPConstantPropagation.cpp
+++ lib/Transforms/IPO/IPConstantPropagation.cpp
@@ -62,32 +62,38 @@
     // Ignore blockaddress uses.
     if (isa<BlockAddress>(UR)) continue;
 
-    // Used by a non-instruction, or not the callee of a function, do not
-    // transform.
-    if (!isa<CallInst>(UR) && !isa<InvokeInst>(UR))
-      return false;
-
-    CallSite CS(cast<Instruction>(UR));
-    if (!CS.isCallee(&U))
+    // If no abstract call site was created the use is invalid.
+    AbstractCallSite ACS(&U);
+    if (!ACS)
       return false;
+
+    // Every parameter inspected below needs a corresponding call site operand,
+    // but a callback may map fewer operands than the callee has parameters.
+    if (ACS.getNumArgOperands() < ArgumentConstants.size())
+      return false;
 
     // Check out all of the potentially constant arguments.  Note that we don't
     // inspect varargs here.
-    CallSite::arg_iterator AI = CS.arg_begin();
     Function::arg_iterator Arg = F.arg_begin();
-    for (unsigned i = 0, e = ArgumentConstants.size(); i != e;
-         ++i, ++AI, ++Arg) {
+    for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++Arg) {
 
       // If this argument is known non-constant, ignore it.
       if (ArgumentConstants[i].second)
         continue;
 
-      Constant *C = dyn_cast<Constant>(*AI);
+      Value *V = ACS.getCallArgOperand(i);
+      Constant *C = dyn_cast_or_null<Constant>(V);
+
+      // A callback callee may declare parameter types that differ from the
+      // operands given to the broker; such values cannot replace the argument.
+      if (C && C->getType() != Arg->getType())
+        C = nullptr;
+
       if (C && ArgumentConstants[i].first == nullptr) {
         ArgumentConstants[i].first = C;   // First constant seen.
       } else if (C && ArgumentConstants[i].first == C) {
         // Still the constant value we think it is.
-      } else if (*AI == &*Arg) {
+      } else if (V == &*Arg) {
         // Ignore recursive calls passing argument down.
} else { // Argument became non-constant. If all arguments are non-constant now, Index: test/Transforms/IPConstantProp/openmp_parallel_for.ll =================================================================== --- /dev/null +++ test/Transforms/IPConstantProp/openmp_parallel_for.ll @@ -0,0 +1,117 @@ +; RUN: opt -S -ipconstprop < %s | FileCheck %s +; +; void bar(int, float, double); +; +; void foo(int N) { +; float p = 3; +; double q = 5; +; N = 7; +; +; #pragma omp parallel for firstprivate(q) +; for (int i = 2; i < N; i++) { +; bar(i, p, q); +; } +; } +; +; Verify the constant value of q is propagated into the outlined function. +; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@.str = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +@0 = private unnamed_addr global %struct.ident_t { i32 0, i32 514, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @.str, i32 0, i32 0) }, align 8 +@1 = private unnamed_addr global %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([23 x i8], [23 x i8]* @.str, i32 0, i32 0) }, align 8 + +define dso_local void @foo(i32 %N) { +entry: + %N.addr = alloca i32, align 4 + %p = alloca float, align 4 + store i32 %N, i32* %N.addr, align 4 + store float 3.000000e+00, float* %p, align 4 + store i32 7, i32* %N.addr, align 4 + call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* nonnull @1, i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, float*, i64)* @.omp_outlined. 
to void (i32*, i32*, ...)*), i32* nonnull %N.addr, float* nonnull %p, i64 4617315517961601024) + ret void +} + +define internal void @.omp_outlined.(i32* noalias %.global_tid., i32* noalias %.bound_tid., i32* dereferenceable(4) %N, float* dereferenceable(4) %p, i64 %q) { +entry: + %q.addr = alloca i64, align 8 + %.omp.lb = alloca i32, align 4 + %.omp.ub = alloca i32, align 4 + %.omp.stride = alloca i32, align 4 + %.omp.is_last = alloca i32, align 4 +; CHECK: store i64 4617315517961601024, i64* %q.addr, align 8 + store i64 %q, i64* %q.addr, align 8 + %conv = bitcast i64* %q.addr to double* + %tmp = load i32, i32* %N, align 4 + %sub3 = add nsw i32 %tmp, -3 + %cmp = icmp sgt i32 %tmp, 2 + br i1 %cmp, label %omp.precond.then, label %omp.precond.end + +omp.precond.then: ; preds = %entry + store i32 0, i32* %.omp.lb, align 4 + store i32 %sub3, i32* %.omp.ub, align 4 + store i32 1, i32* %.omp.stride, align 4 + store i32 0, i32* %.omp.is_last, align 4 + %tmp5 = load i32, i32* %.global_tid., align 4 + call void @__kmpc_for_static_init_4(%struct.ident_t* nonnull @0, i32 %tmp5, i32 34, i32* nonnull %.omp.is_last, i32* nonnull %.omp.lb, i32* nonnull %.omp.ub, i32* nonnull %.omp.stride, i32 1, i32 1) + %tmp6 = load i32, i32* %.omp.ub, align 4 + %cmp6 = icmp sgt i32 %tmp6, %sub3 + br i1 %cmp6, label %cond.true, label %cond.false + +cond.true: ; preds = %omp.precond.then + br label %cond.end + +cond.false: ; preds = %omp.precond.then + %tmp7 = load i32, i32* %.omp.ub, align 4 + br label %cond.end + +cond.end: ; preds = %cond.false, %cond.true + %cond = phi i32 [ %sub3, %cond.true ], [ %tmp7, %cond.false ] + store i32 %cond, i32* %.omp.ub, align 4 + %tmp8 = load i32, i32* %.omp.lb, align 4 + br label %omp.inner.for.cond + +omp.inner.for.cond: ; preds = %omp.inner.for.inc, %cond.end + %.omp.iv.0 = phi i32 [ %tmp8, %cond.end ], [ %add11, %omp.inner.for.inc ] + %tmp9 = load i32, i32* %.omp.ub, align 4 + %cmp8 = icmp sgt i32 %.omp.iv.0, %tmp9 + br i1 %cmp8, label 
%omp.inner.for.cond.cleanup, label %omp.inner.for.body + +omp.inner.for.cond.cleanup: ; preds = %omp.inner.for.cond + br label %omp.inner.for.end + +omp.inner.for.body: ; preds = %omp.inner.for.cond + %add10 = add nsw i32 %.omp.iv.0, 2 + %tmp10 = load float, float* %p, align 4 + %tmp11 = load double, double* %conv, align 8 + call void @bar(i32 %add10, float %tmp10, double %tmp11) + br label %omp.body.continue + +omp.body.continue: ; preds = %omp.inner.for.body + br label %omp.inner.for.inc + +omp.inner.for.inc: ; preds = %omp.body.continue + %add11 = add nsw i32 %.omp.iv.0, 1 + br label %omp.inner.for.cond + +omp.inner.for.end: ; preds = %omp.inner.for.cond.cleanup + br label %omp.loop.exit + +omp.loop.exit: ; preds = %omp.inner.for.end + %tmp12 = load i32, i32* %.global_tid., align 4 + call void @__kmpc_for_static_fini(%struct.ident_t* nonnull @0, i32 %tmp12) + br label %omp.precond.end + +omp.precond.end: ; preds = %omp.loop.exit, %entry + ret void +} + +declare dso_local void @__kmpc_for_static_init_4(%struct.ident_t*, i32, i32, i32*, i32*, i32*, i32*, i32, i32) + +declare dso_local void @bar(i32, float, double) + +declare dso_local void @__kmpc_for_static_fini(%struct.ident_t*, i32) + +declare dso_local void @__kmpc_fork_call(%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) Index: test/Transforms/IPConstantProp/pthreads.ll =================================================================== --- /dev/null +++ test/Transforms/IPConstantProp/pthreads.ll @@ -0,0 +1,46 @@ +; RUN: opt -ipconstprop -S < %s | FileCheck %s +; +; #include +; +; void *GlobalVPtr; +; +; static void *foo(void *arg) { return arg; } +; static void *bar(void *arg) { return arg; } +; +; int main() { +; pthread_t thread; +; pthread_create(&thread, NULL, foo, NULL); +; pthread_create(&thread, NULL, bar, &GlobalVPtr); +; return 0; +; } +; +; Verify the constant values NULL and &GlobalVPtr are propagated into foo and +; bar, respectively. 
+; +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +%union.pthread_attr_t = type { i64, [48 x i8] } + +@GlobalVPtr = common dso_local global i8* null, align 8 + +define dso_local i32 @main() { +entry: + %thread = alloca i64, align 8 + %call = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @foo, i8* null) + %call1 = call i32 @pthread_create(i64* nonnull %thread, %union.pthread_attr_t* null, i8* (i8*)* nonnull @bar, i8* bitcast (i8** @GlobalVPtr to i8*)) + ret i32 0 +} + +declare dso_local i32 @pthread_create(i64*, %union.pthread_attr_t*, i8* (i8*)*, i8*) + +define internal i8* @foo(i8* %arg) { +entry: +; CHECK: ret i8* null + ret i8* %arg +} + +define internal i8* @bar(i8* %arg) { +entry: +; CHECK: ret i8* bitcast (i8** @GlobalVPtr to i8*) + ret i8* %arg +}