diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" +#include "llvm/IR/Assumptions.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" @@ -1450,6 +1451,12 @@ } }; +/// The "omp_no_external_caller_in_target_region" assumption guarantees that +/// a function has no external callers that are inside an OpenMP target +/// region. +static KnownAssumptionString + NoExternalCallerInTargetRegion("omp_no_external_caller_in_target_region"); + Kernel OpenMPOpt::getUniqueKernelFor(Function &F) { if (!OMPInfoCache.ModuleSlice.count(&F)) return nullptr; @@ -1469,7 +1476,8 @@ } CachedKernel = nullptr; - if (!F.hasLocalLinkage()) + if (!F.hasLocalLinkage() && + !hasAssumption(F, NoExternalCallerInTargetRegion)) return nullptr; } diff --git a/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll --- a/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll +++ b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll @@ -7,12 +7,18 @@ ; #pragma omp parallel ; { } ; } +; __attribute__((assume("omp_no_external_caller_in_target_region"))) +; void baz(void) { +; #pragma omp parallel +; { } +; } ; void foo(void) { ; #pragma omp target teams ; { ; #pragma omp parallel ; {} ; bar(); +; baz(); ; #pragma omp parallel ; {} ; } @@ -23,13 +29,16 @@ ; another kernel. 
; CHECK-DAG: @__omp_outlined__1_wrapper.ID = private constant i8 undef +; CHECK-DAG: @__omp_outlined__2b_wrapper.ID = private constant i8 undef ; CHECK-DAG: @__omp_outlined__3_wrapper.ID = private constant i8 undef ; CHECK-DAG: icmp eq i8* %5, @__omp_outlined__1_wrapper.ID +; CHECK-DAG: icmp eq i8* %b6, @__omp_outlined__2b_wrapper.ID ; CHECK-DAG: icmp eq i8* %7, @__omp_outlined__3_wrapper.ID ; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__1_wrapper.ID) -; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*)) +; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2a_wrapper to i8*)) +; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__2b_wrapper.ID) ; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__3_wrapper.ID) @@ -69,11 +78,20 @@ .check.next: ; preds = %.execute.parallel %6 = load i8*, i8** %work_fn, align 8 - %work_match1 = icmp eq i8* %6, bitcast (void ()* @__omp_outlined__2_wrapper to i8*) - br i1 %work_match1, label %.execute.fn2, label %.check.next3 + %work_match1 = icmp eq i8* %6, bitcast (void ()* @__omp_outlined__2a_wrapper to i8*) + br i1 %work_match1, label %.execute.fn2a, label %.check.next2 + +.execute.fn2a: ; preds = %.check.next + call void @__omp_outlined__2a_wrapper() + br label %.terminate.parallel + +.check.next2: ; preds = %.check.next + %b6 = load i8*, i8** %work_fn, align 8 + %work_match1b = icmp eq i8* %b6, bitcast (void ()* @__omp_outlined__2b_wrapper to i8*) + br i1 %work_match1b, label %.execute.fn2b, label %.check.next3 -.execute.fn2: ; preds = %.check.next - call void @__omp_outlined__2_wrapper() +.execute.fn2b: ; preds = %.check.next2 + call void @__omp_outlined__2b_wrapper() br label %.terminate.parallel .check.next3: ; preds = %.check.next @@ -111,6 +129,7 @@ define internal void @__omp_outlined__() { call void @__kmpc_kernel_prepare_parallel(i8* bitcast 
(void ()* @__omp_outlined__1_wrapper to i8*)) call void @bar() + call void @baz() call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__3_wrapper to i8*)) ret void } @@ -125,11 +144,20 @@ } define hidden void @bar() { - call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*)) + call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2a_wrapper to i8*)) + ret void +} + +define hidden void @baz() #0 { + call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2b_wrapper to i8*)) + ret void +} + +define internal void @__omp_outlined__2a_wrapper() { ret void } -define internal void @__omp_outlined__2_wrapper() { +define internal void @__omp_outlined__2b_wrapper() { ret void } @@ -147,6 +175,7 @@ declare i32 @__kmpc_global_thread_num(%struct.ident_t* nocapture readnone) +attributes #0 = { "llvm.assume"="abc,omp_no_external_caller_in_target_region,123" } !nvvm.annotations = !{!0}