[OpenMP] Attach !callback metadata to __kmpc_parallel_51

OMPIRBuilder.cpp (getOrCreateRuntimeFunction):
- The !callback annotation for __kmpc_fork_call / __kmpc_fork_teams is moved
  out of the branch that logs "Created OpenMP runtime function" and calls
  addAttributes, to after the surrounding if/else. It therefore runs no matter
  which branch was taken, presumably including when the declaration already
  existed in the module. The `!Fn->hasMetadata(LLVMContext::MD_callback)`
  guard skips functions that already carry callback metadata.
- __kmpc_parallel_51 gets a new annotation,
  createCallbackEncoding(5, {-1, -1, 7}, /*VarArgsArePassed*/ false):
  - the callback callee is argument 5;
  - the callee's first two arguments are unknown (-1);
  - the callee's third argument is the runtime function's argument 7;
  - no variadic arguments are forwarded.

OpenMPOpt.cpp:
- OMPInformationCache no longer owns an OpenMPIRBuilder. It now stores an
  `OpenMPIRBuilder &` passed as a new trailing constructor parameter, and it
  no longer calls OMPBuilder.initialize() itself.
- Every visible caller now creates and initializes a local OpenMPIRBuilder and
  passes it in.
- The module pass additionally calls
  getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51) up front, so the
  callback annotation is attached.
- The two later call sites construct the cache with M instead of
  *(Functions.back()->getParent()).

NOTE(review): In getOrCreateRuntimeFunction, the moved callback block
dereferences Fn (Fn->hasMetadata / Fn->addMetadata) before the existing
`assert(Fn && "Failed to create OpenMP runtime function")`. Consider moving
the assert above the new block.

NOTE(review): The module pass calls
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51) and discards the
result. Presumably this inserts a __kmpc_parallel_51 declaration even into
modules that never call it. Confirm that is intended, or only annotate when
the function is already present.

NOTE(review): OMPInformationCache now holds a reference, so the
OpenMPIRBuilder passed in must outlive the cache. At the visible call sites
the builder is a local declared before the cache; keep it that way.

NOTE(review): This copy of the patch has lost its line breaks. It also
appears to have had `<...>` template arguments stripped, for example:
- `SetVector &CGSCC`
- `SmallPtrSetImpl &Kernels`
- the fragment `EnumeratedArray(M).getManager();`, which swallows the text
  between an `EnumeratedArray<` and a later `>`.
It will not apply or compile as-is; regenerate it from the source tree.

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -97,23 +97,6 @@ #include "llvm/Frontend/OpenMP/OMPKinds.def" } - // Add information if the runtime function takes a callback function - if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) { - if (!Fn->hasMetadata(LLVMContext::MD_callback)) { - LLVMContext &Ctx = Fn->getContext(); - MDBuilder MDB(Ctx); - // Annotate the callback behavior of the runtime function: - // - The callback callee is argument number 2 (microtask). - // - The first two arguments of the callback callee are unknown (-1). - // - All variadic arguments to the runtime function are passed to the - // callback callee. - Fn->addMetadata( - LLVMContext::MD_callback, - *MDNode::get(Ctx, {MDB.createCallbackEncoding( - 2, {-1, -1}, /* VarArgsArePassed */ true)})); - } - } - LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName() << " with type " << *Fn->getFunctionType() << "\n"); addAttributes(FnID, *Fn); @@ -123,6 +106,36 @@ << " with type " << *Fn->getFunctionType() << "\n"); } + // Add information if the runtime function takes a callback function + if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) { + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + // Annotate the callback behavior of the runtime function: + // - The callback callee is argument number 2 (microtask). + // - The first two arguments of the callback callee are unknown (-1). + // - All variadic arguments to the runtime function are passed to the + // callback callee. 
+ Fn->addMetadata( + LLVMContext::MD_callback, + *MDNode::get(Ctx, {MDB.createCallbackEncoding( + 2, {-1, -1}, /* VarArgsArePassed */ true)})); + } + } else if (FnID == OMPRTL___kmpc_parallel_51) { + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + // Annotate the callback behavior of the runtime function: + // - The callback callee is argument number 5 (outlined function). + // - The first two arguments of the callback callee are unknown (-1). + // - Argument 7 (args) is passed as the callee's third argument; no variadic arguments are forwarded. + Fn->addMetadata(LLVMContext::MD_callback, + *MDNode::get(Ctx, {MDB.createCallbackEncoding( + 5, {-1, -1, 7}, + /* VarArgsArePassed */ false)})); + } + } + assert(Fn && "Failed to create OpenMP runtime function"); // Cast the function to the expected type if necessary diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -150,11 +150,11 @@ struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, BumpPtrAllocator &Allocator, SetVector &CGSCC, - SmallPtrSetImpl &Kernels) - : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), + SmallPtrSetImpl &Kernels, + OpenMPIRBuilder &OMPBuilder) + : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(OMPBuilder), Kernels(Kernels) { - OMPBuilder.initialize(); initializeRuntimeFunctions(); initializeInternalControlVars(); } @@ -285,7 +285,7 @@ }; /// An OpenMP-IR-Builder instance - OpenMPIRBuilder OMPBuilder; + OpenMPIRBuilder &OMPBuilder; /// Map from runtime function kind to the runtime function description. EnumeratedArray(M).getManager(); KernelSet Kernels = getDeviceKernels(M); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); + // Make sure we have the callback annotations attached to the + // __kmpc_parallel_51 function declaration/definition. 
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51); + auto IsCalled = [&](Function &F) { if (Kernels.contains(&F)) return true; @@ -4497,9 +4503,9 @@ BumpPtrAllocator Allocator; CallGraphUpdater CGUpdater; - SetVector Functions(SCC.begin(), SCC.end()); - OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); + OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels, + OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, @@ -4553,9 +4559,11 @@ CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, C, AM, UR); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); SetVector Functions(SCC.begin(), SCC.end()); - OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, - /*CGSCC*/ Functions, Kernels); + OMPInformationCache InfoCache(M, AG, Allocator, + /*CGSCC*/ Functions, Kernels, OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, @@ -4623,9 +4631,10 @@ AnalysisGetter AG; SetVector Functions(SCC.begin(), SCC.end()); BumpPtrAllocator Allocator; - OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, - Allocator, - /*CGSCC*/ Functions, Kernels); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); + OMPInformationCache InfoCache(M, AG, Allocator, + /*CGSCC*/ Functions, Kernels, OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,