diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -97,23 +97,6 @@ #include "llvm/Frontend/OpenMP/OMPKinds.def" } - // Add information if the runtime function takes a callback function - if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) { - if (!Fn->hasMetadata(LLVMContext::MD_callback)) { - LLVMContext &Ctx = Fn->getContext(); - MDBuilder MDB(Ctx); - // Annotate the callback behavior of the runtime function: - // - The callback callee is argument number 2 (microtask). - // - The first two arguments of the callback callee are unknown (-1). - // - All variadic arguments to the runtime function are passed to the - // callback callee. - Fn->addMetadata( - LLVMContext::MD_callback, - *MDNode::get(Ctx, {MDB.createCallbackEncoding( - 2, {-1, -1}, /* VarArgsArePassed */ true)})); - } - } - LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName() << " with type " << *Fn->getFunctionType() << "\n"); addAttributes(FnID, *Fn); @@ -123,6 +106,36 @@ << " with type " << *Fn->getFunctionType() << "\n"); } + // Add information if the runtime function takes a callback function + if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) { + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + // Annotate the callback behavior of the runtime function: + // - The callback callee is argument number 2 (microtask). + // - The first two arguments of the callback callee are unknown (-1). + // - All variadic arguments to the runtime function are passed to the + // callback callee. + Fn->addMetadata( + LLVMContext::MD_callback, + *MDNode::get(Ctx, {MDB.createCallbackEncoding( + 2, {-1, -1}, /* VarArgsArePassed */ true)})); + } + } else if (FnID == OMPRTL___kmpc_parallel_51) { + if (!Fn->hasMetadata(LLVMContext::MD_callback)) { + LLVMContext &Ctx = Fn->getContext(); + MDBuilder MDB(Ctx); + // Annotate the callback behavior of the runtime function: + // - The callback callee is argument number 5 (outlined function). + // - The first two arguments of the callback callee are unknown (-1). + // - Argument 7, (args) is passed next, no variadic arguments are used. + Fn->addMetadata(LLVMContext::MD_callback, + *MDNode::get(Ctx, {MDB.createCallbackEncoding( + 5, {-1, -1, 7}, + /* VarArgsArePassed */ false)})); + } + } + assert(Fn && "Failed to create OpenMP runtime function"); // Cast the function to the expected type if necessary diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -150,11 +150,11 @@ struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, BumpPtrAllocator &Allocator, SetVector &CGSCC, - SmallPtrSetImpl &Kernels) - : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M), + SmallPtrSetImpl &Kernels, + OpenMPIRBuilder &OMPBuilder) + : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(OMPBuilder), Kernels(Kernels) { - OMPBuilder.initialize(); initializeRuntimeFunctions(); initializeInternalControlVars(); } @@ -285,7 +285,7 @@ }; /// An OpenMP-IR-Builder instance - OpenMPIRBuilder OMPBuilder; + OpenMPIRBuilder &OMPBuilder; /// Map from runtime function kind to the runtime function description. EnumeratedArray(M).getManager(); KernelSet Kernels = getDeviceKernels(M); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); + // Make sure we have the callback annotations attached to the + // __kmpc_parallel_51 function declaration/definition. + OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51); + auto IsCalled = [&](Function &F) { if (Kernels.contains(&F)) return true; @@ -4497,9 +4503,9 @@ BumpPtrAllocator Allocator; CallGraphUpdater CGUpdater; - SetVector Functions(SCC.begin(), SCC.end()); - OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels); + OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels, + OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false, @@ -4553,9 +4559,11 @@ CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, C, AM, UR); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); SetVector Functions(SCC.begin(), SCC.end()); - OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, - /*CGSCC*/ Functions, Kernels); + OMPInformationCache InfoCache(M, AG, Allocator, + /*CGSCC*/ Functions, Kernels, OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, @@ -4623,9 +4631,10 @@ AnalysisGetter AG; SetVector Functions(SCC.begin(), SCC.end()); BumpPtrAllocator Allocator; - OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, - Allocator, - /*CGSCC*/ Functions, Kernels); + OpenMPIRBuilder OMPBuilder(M); + OMPBuilder.initialize(); + OMPInformationCache InfoCache(M, AG, Allocator, + /*CGSCC*/ Functions, Kernels, OMPBuilder); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32; Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true, diff --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll --- a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll +++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll @@ -108,7 +108,7 @@ define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) { ; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51 -; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 1073741824 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 1073741824 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 1073741824 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 1073741824 [[TMP7:%.*]], i64 [[TMP8:%.*]]) #[[ATTR0:[0-9]+]] { +; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 1073741824 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 1073741824 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 1073741824 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 1073741824 [[TMP7:%.*]], i64 [[TMP8:%.*]]) #[[ATTR0:[0-9]+]] !callback !5 { ; CHECK-NEXT: call void @parallel_helper() ; CHECK-NEXT: ret void ; @@ -149,4 +149,6 @@ ; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1} ; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1} ; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1} +; CHECK: [[META5:![0-9]+]] = !{!6} +; CHECK: [[META6:![0-9]+]] = !{i64 5, i64 -1, i64 -1, i64 7, i1 false} ;.