diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -52,6 +52,9 @@
           "Number of OpenMP runtime function uses identified");
 STATISTIC(NumOpenMPTargetRegionKernels,
           "Number of OpenMP target region entry points (=kernels) identified");
+STATISTIC(
+    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
+    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
 
 #if !defined(NDEBUG)
 static constexpr auto TAG = "[" DEBUG_TYPE "]";
@@ -496,6 +499,8 @@
     if (PrintOpenMPKernels)
       printKernels();
 
+    Changed |= rewriteDeviceCodeStateMachine();
+
     Changed |= runAttributor();
 
     // Recollect uses, in case Attributor deleted any.
@@ -849,6 +854,31 @@
       AddUserArgs(*GTIdArgs[u]);
   }
 
+  /// Kernel (=GPU) optimizations and utility functions
+  ///
+  ///{{
+
+  /// Check if \p F is a kernel, hence entry point for target offloading.
+  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
+
+  /// Cache to remember the unique kernel for a function.
+  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
+
+  /// Find the unique kernel that will execute \p F, if any.
+  Kernel getUniqueKernelFor(Function &F);
+
+  /// Find the unique kernel that will execute \p I, if any.
+  Kernel getUniqueKernelFor(Instruction &I) {
+    return getUniqueKernelFor(*I.getFunction());
+  }
+
+  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
+  /// the cases we can avoid taking the address of a function.
+  bool rewriteDeviceCodeStateMachine();
+
+  ///
+  ///}}
+
   /// Emit a remark generically
   ///
   /// This template function can be used to generically emit a remark. The
@@ -930,6 +960,153 @@
   }
 };
 
+/// Find the unique kernel that will execute \p F, if any.
+///
+/// Returns nullptr when \p F is outside the module slice, when it is not a
+/// kernel and does not have local linkage, or when walking its uses does not
+/// yield exactly one kernel. Results are memoized in UniqueKernelMap.
+Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
+  if (!OMPInfoCache.ModuleSlice.count(&F))
+    return nullptr;
+
+  // Use a scope to keep the lifetime of the CachedKernel short.
+  {
+    Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
+    if (CachedKernel)
+      return *CachedKernel;
+
+    // TODO: We should use an AA to create an (optimistic and callback
+    //       call-aware) call graph. For now we stick to simple patterns that
+    //       are less powerful, basically the worst fixpoint.
+    if (isKernel(F)) {
+      CachedKernel = Kernel(&F);
+      return *CachedKernel;
+    }
+
+    // Record "no unique kernel" before looking at the uses; a recursive query
+    // for F issued while walking them then hits the cache and returns nullptr.
+    CachedKernel = nullptr;
+    if (!F.hasLocalLinkage())
+      return nullptr;
+  }
+
+  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
+    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
+      // Allow use in equality comparisons.
+      if (Cmp->isEquality())
+        return getUniqueKernelFor(*Cmp);
+      return nullptr;
+    }
+    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
+      // Allow direct calls.
+      if (CB->isCallee(&U))
+        return getUniqueKernelFor(*CB);
+      // Allow the use in __kmpc_kernel_prepare_parallel calls.
+      if (Function *Callee = CB->getCalledFunction())
+        if (Callee->getName() == "__kmpc_kernel_prepare_parallel")
+          return getUniqueKernelFor(*CB);
+      return nullptr;
+    }
+    // Disallow every other use.
+    return nullptr;
+  };
+
+  // TODO: In the future we want to track more than just a unique kernel.
+  SmallPtrSet<Kernel, 2> PotentialKernels;
+  foreachUse(F, [&](const Use &U) {
+    PotentialKernels.insert(GetUniqueKernelForUse(U));
+  });
+
+  Kernel K = nullptr;
+  if (PotentialKernels.size() == 1)
+    K = *PotentialKernels.begin();
+
+  // Cache the result.
+  UniqueKernelMap[&F] = K;
+
+  return K;
+}
+
+/// Replace function pointers used for identification in the device state
+/// machine by the address of a fresh private global.
+///
+/// A function in the SCC is rewritten only if it has exactly one direct call,
+/// exactly two uses that are compare operands or operands of a regular
+/// __kmpc_kernel_prepare_parallel call, no other use, and a unique kernel.
+///
+/// \returns true if at least one function was rewritten.
+bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
+  OMPInformationCache::RuntimeFunctionInfo &KernelPrepareParallelRFI =
+      OMPInfoCache.RFIs[OMPRTL___kmpc_kernel_prepare_parallel];
+
+  bool Changed = false;
+  if (!KernelPrepareParallelRFI)
+    return Changed;
+
+  for (Function *F : SCC) {
+
+    // Check if the function is used in a __kmpc_kernel_prepare_parallel call at
+    // all.
+    bool UnknownUse = false;
+    unsigned NumDirectCalls = 0;
+
+    SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
+    foreachUse(*F, [&](Use &U) {
+      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
+        if (CB->isCallee(&U)) {
+          ++NumDirectCalls;
+          return;
+        }
+
+      if (isa<ICmpInst>(U.getUser())) {
+        ToBeReplacedStateMachineUses.push_back(&U);
+        return;
+      }
+      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(),
+                                          &KernelPrepareParallelRFI)) {
+        ToBeReplacedStateMachineUses.push_back(&U);
+        return;
+      }
+      UnknownUse = true;
+    });
+
+    // If this ever hits, we should investigate.
+    if (UnknownUse || NumDirectCalls != 1)
+      continue;
+
+    // TODO: This is not a necessary restriction and should be lifted.
+    if (ToBeReplacedStateMachineUses.size() != 2)
+      continue;
+
+    // Even if we have __kmpc_kernel_prepare_parallel calls, we (for now) give
+    // up if the function is not called from a unique kernel.
+    Kernel K = getUniqueKernelFor(*F);
+    if (!K)
+      continue;
+
+    // We now know F is a parallel body function called only from the kernel K.
+    // We also identified the state machine uses in which we replace the
+    // function pointer by a new global symbol for identification purposes. This
+    // ensures only direct calls to the function are left.
+
+    Module &M = *F->getParent();
+    Type *Int8Ty = Type::getInt8Ty(M.getContext());
+
+    auto *ID = new GlobalVariable(
+        M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
+        UndefValue::get(Int8Ty), F->getName() + ".ID");
+
+    for (Use *U : ToBeReplacedStateMachineUses)
+      U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
+
+    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
+
+    Changed = true;
+  }
+
+  return Changed;
+}
+
 /// Abstract Attribute for tracking ICV values.
 struct AAICVTracker : public StateWrapper {
   using Base = StateWrapper;
diff --git a/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/gpu_state_machine_function_ptr_replacement.ll
@@ -0,0 +1,153 @@
+; RUN: opt -S -passes=openmpopt -pass-remarks=openmp-opt -openmp-print-gpu-kernels < %s | FileCheck %s
+; RUN: opt -S -openmpopt -pass-remarks=openmp-opt -openmp-print-gpu-kernels < %s | FileCheck %s
+
+; C input used for this test:
+
+; void bar(void) {
+;     #pragma omp parallel
+;     { }
+; }
+; void foo(void) {
+;   #pragma omp target teams
+;   {
+;     #pragma omp parallel
+;     {}
+;     bar();
+;     #pragma omp parallel
+;     {}
+;   }
+; }
+
+; Verify we replace the function pointer uses for the first and last outlined
+; region (1 and 3) but not for the middle one (2): it is used by @bar, which is
+; not internal and could hence be called from another kernel.
+
+; CHECK-DAG: @__omp_outlined__1_wrapper.ID = private constant i8 undef
+; CHECK-DAG: @__omp_outlined__3_wrapper.ID = private constant i8 undef
+
+; CHECK-DAG: icmp eq i8* %5, @__omp_outlined__1_wrapper.ID
+; CHECK-DAG: icmp eq i8* %7, @__omp_outlined__3_wrapper.ID
+
+; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__1_wrapper.ID)
+; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*))
+; CHECK-DAG: call void @__kmpc_kernel_prepare_parallel(i8* @__omp_outlined__3_wrapper.ID)
+
+
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+define internal void @__omp_offloading_35_a1e179_foo_l7_worker() {
+entry:
+  %work_fn = alloca i8*, align 8
+  %exec_status = alloca i8, align 1
+  store i8* null, i8** %work_fn, align 8
+  store i8 0, i8* %exec_status, align 1
+  br label %.await.work
+
+.await.work:  ; preds = %.barrier.parallel, %entry
+  call void @__kmpc_barrier_simple_spmd(%struct.ident_t* null, i32 0)
+  %0 = call i1 @__kmpc_kernel_parallel(i8** %work_fn)
+  %1 = zext i1 %0 to i8
+  store i8 %1, i8* %exec_status, align 1
+  %2 = load i8*, i8** %work_fn, align 8
+  %should_terminate = icmp eq i8* %2, null
+  br i1 %should_terminate, label %.exit, label %.select.workers
+
+.select.workers:  ; preds = %.await.work
+  %3 = load i8, i8* %exec_status, align 1
+  %is_active = icmp ne i8 %3, 0
+  br i1 %is_active, label %.execute.parallel, label %.barrier.parallel
+
+.execute.parallel:  ; preds = %.select.workers
+  %4 = call i32 @__kmpc_global_thread_num(%struct.ident_t* null)
+  %5 = load i8*, i8** %work_fn, align 8
+  %work_match = icmp eq i8* %5, bitcast (void ()* @__omp_outlined__1_wrapper to i8*)
+  br i1 %work_match, label %.execute.fn, label %.check.next
+
+.execute.fn:  ; preds = %.execute.parallel
+  call void @__omp_outlined__1_wrapper()
+  br label %.terminate.parallel
+
+.check.next:  ; preds = %.execute.parallel
+  %6 = load i8*, i8** %work_fn, align 8
+  %work_match1 = icmp eq i8* %6, bitcast (void ()* @__omp_outlined__2_wrapper to i8*)
+  br i1 %work_match1, label %.execute.fn2, label %.check.next3
+
+.execute.fn2:  ; preds = %.check.next
+  call void @__omp_outlined__2_wrapper()
+  br label %.terminate.parallel
+
+.check.next3:  ; preds = %.check.next
+  %7 = load i8*, i8** %work_fn, align 8
+  %work_match4 = icmp eq i8* %7, bitcast (void ()* @__omp_outlined__3_wrapper to i8*)
+  br i1 %work_match4, label %.execute.fn5, label %.check.next6
+
+.execute.fn5:  ; preds = %.check.next3
+  call void @__omp_outlined__3_wrapper()
+  br label %.terminate.parallel
+
+.check.next6:  ; preds = %.check.next3
+  %8 = bitcast i8* %2 to void ()*
+  call void %8()
+  br label %.terminate.parallel
+
+.terminate.parallel:  ; preds = %.check.next6, %.execute.fn5, %.execute.fn2, %.execute.fn
+  call void @__kmpc_kernel_end_parallel()
+  br label %.barrier.parallel
+
+.barrier.parallel:  ; preds = %.terminate.parallel, %.select.workers
+  call void @__kmpc_barrier_simple_spmd(%struct.ident_t* null, i32 0)
+  br label %.await.work
+
+.exit:  ; preds = %.await.work
+  ret void
+}
+
+define weak void @__omp_offloading_35_a1e179_foo_l7() {
+  call void @__omp_offloading_35_a1e179_foo_l7_worker()
+  call void @__omp_outlined__()
+  ret void
+}
+
+define internal void @__omp_outlined__() {
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__1_wrapper to i8*))
+  call void @bar()
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__3_wrapper to i8*))
+  ret void
+}
+
+define internal void @__omp_outlined__1() {
+  ret void
+}
+
+define internal void @__omp_outlined__1_wrapper() {
+  call void @__omp_outlined__1()
+  ret void
+}
+
+define hidden void @bar() {
+  call void @__kmpc_kernel_prepare_parallel(i8* bitcast (void ()* @__omp_outlined__2_wrapper to i8*))
+  ret void
+}
+
+define internal void @__omp_outlined__2_wrapper() {
+  ret void
+}
+
+define internal void @__omp_outlined__3_wrapper() {
+  ret void
+}
+
+declare void @__kmpc_kernel_prepare_parallel(i8* %WorkFn)
+
+declare zeroext i1 @__kmpc_kernel_parallel(i8** nocapture %WorkFn)
+
+declare void @__kmpc_kernel_end_parallel()
+
+declare void @__kmpc_barrier_simple_spmd(%struct.ident_t* nocapture readnone %loc_ref, i32 %tid)
+
+declare i32 @__kmpc_global_thread_num(%struct.ident_t* nocapture readnone)
+
+
+!nvvm.annotations = !{!0}
+
+!0 = !{void ()* @__omp_offloading_35_a1e179_foo_l7, !"kernel", i32 1}