diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h --- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h +++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h @@ -37,13 +37,27 @@ /// OpenMP optimizations pass. class OpenMPOptPass : public PassInfoMixin<OpenMPOptPass> { public: + OpenMPOptPass() = default; + OpenMPOptPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); + +private: + /// LTO phase this pass instance was created for; None unless one is given. + const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None; }; class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> { public: + OpenMPOptCGSCCPass() = default; + OpenMPOptCGSCCPass(ThinOrFullLTOPhase LTOPhase) : LTOPhase(LTOPhase) {} + PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR); + +private: + /// LTO phase this pass instance was created for; None unless one is given. + const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None; }; } // end namespace llvm diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp --- a/llvm/lib/Passes/PassBuilderPipelines.cpp +++ b/llvm/lib/Passes/PassBuilderPipelines.cpp @@ -1604,7 +1604,7 @@ } // Try to run OpenMP optimizations, quick no-op if no OpenMP metadata present. - MPM.addPass(OpenMPOptPass()); + MPM.addPass(OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink)); // Remove unused virtual tables to improve the quality of code generated by // whole-program devirtualization and bitset lowering. @@ -1808,7 +1808,8 @@ addVectorPasses(Level, MainFPM, /* IsFullLTO */ true); // Run the OpenMPOpt CGSCC pass again late. 
- MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(OpenMPOptCGSCCPass())); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor( + OpenMPOptCGSCCPass(ThinOrFullLTOPhase::FullLTOPostLink))); invokePeepholeEPCallbacks(MainFPM, Level); MainFPM.addPass(JumpThreadingPass()); diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -44,6 +44,7 @@ MODULE_PASS("attributor", AttributorPass()) MODULE_PASS("annotation2metadata", Annotation2MetadataPass()) MODULE_PASS("openmp-opt", OpenMPOptPass()) +MODULE_PASS("openmp-opt-postlink", OpenMPOptPass(ThinOrFullLTOPhase::FullLTOPostLink)) MODULE_PASS("called-value-propagation", CalledValuePropagationPass()) MODULE_PASS("canonicalize-aliases", CanonicalizeAliasesPass()) MODULE_PASS("cg-profile", CGProfilePass()) diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -188,9 +188,9 @@ struct OMPInformationCache : public InformationCache { OMPInformationCache(Module &M, AnalysisGetter &AG, BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC, - KernelSet &Kernels) + KernelSet &Kernels, bool OpenMPPostLink) : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M), - Kernels(Kernels) { + Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) { OMPBuilder.initialize(); initializeRuntimeFunctions(M); @@ -448,6 +448,24 @@ CI->setCallingConv(Fn->getCallingConv()); } + /// Helper to determine whether it is legal to create calls to all of the + /// runtime functions in \p Fns. + bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) { + // We can always emit calls if we haven't yet linked in the runtime. + if (!OpenMPPostLink) + return true; + + // Once the runtime has already been linked in we cannot emit calls to + // any undefined functions. 
+ for (RuntimeFunction Fn : Fns) { + RuntimeFunctionInfo &RFI = RFIs[Fn]; + // Neither a missing nor a merely declared function can be called. + if (!RFI.Declaration || RFI.Declaration->isDeclaration()) + return false; + } + return true; + } + /// Helper to initialize all runtime function information for those defined /// in OpenMPKinds.def. void initializeRuntimeFunctions(Module &M) { @@ -523,6 +541,9 @@ /// Collection of known OpenMP runtime functions.. DenseSet RTLFunctions; + + /// Indicates if we have already linked in the OpenMP device library. + bool OpenMPPostLink = false; }; template @@ -1412,7 +1433,10 @@ Changed |= WasSplit; return WasSplit; }; - RFI.foreachUse(SCC, SplitMemTransfers); + if (OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___tgt_target_data_begin_mapper_issue, + OMPRTL___tgt_target_data_begin_mapper_wait})) + RFI.foreachUse(SCC, SplitMemTransfers); return Changed; } @@ -3912,6 +3936,12 @@ bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) { auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + // We cannot change to SPMD mode if the runtime functions aren't available. 
+ if (!OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___kmpc_get_hardware_thread_id_in_block, + OMPRTL___kmpc_barrier_simple_spmd})) + return false; + if (!SPMDCompatibilityTracker.isAssumed()) { for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) { if (!NonCompatibleI) @@ -4019,6 +4049,13 @@ if (!ReachedKnownParallelRegions.isValidState()) return ChangeStatus::UNCHANGED; + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + if (!OMPInfoCache.runtimeFnsAvailable( + {OMPRTL___kmpc_get_hardware_num_threads_in_block, + OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic, + OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel})) + return ChangeStatus::UNCHANGED; + const int InitModeArgNo = 1; const int InitUseStateMachineArgNo = 2; @@ -4165,7 +4202,6 @@ BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB); Module &M = *Kernel->getParent(); - auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); FunctionCallee BlockHwSizeFn = OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( M, OMPRTL___kmpc_get_hardware_num_threads_in_block); @@ -5341,7 +5377,11 @@ BumpPtrAllocator Allocator; CallGraphUpdater CGUpdater; - OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels); + // Post-link LTO phases imply the OpenMP device library is already linked. + bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink || + LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink; + OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels, + PostLink); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 
SetFixpointIterations : 32; @@ -5415,9 +5455,12 @@ CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, C, AM, UR); + // Post-link LTO phases imply the OpenMP device library is already linked. + bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink || + LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink; SetVector<Function *> Functions(SCC.begin(), SCC.end()); OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator, - /*CGSCC*/ &Functions, Kernels); + /*CGSCC*/ &Functions, Kernels, PostLink); unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? SetFixpointIterations : 32; diff --git a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll --- a/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll +++ b/llvm/test/Transforms/OpenMP/custom_state_machines_pre_lto.ll @@ -2,7 +2,9 @@ ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=AMDGPU +; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU ; RUN: opt --mtriple=nvptx64-- -openmp-opt-disable-state-machine-rewrite -S -passes=openmp-opt < %s | FileCheck %s --check-prefix=NVPTX +; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX ;; void p0(void); ;; void p1(void); diff --git a/llvm/test/Transforms/OpenMP/spmdization.ll b/llvm/test/Transforms/OpenMP/spmdization.ll --- a/llvm/test/Transforms/OpenMP/spmdization.ll +++ b/llvm/test/Transforms/OpenMP/spmdization.ll @@ -2,7 +2,9 @@ ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt < %s | FileCheck %s --check-prefixes=AMDGPU ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt < %s | FileCheck %s 
--check-prefixes=NVPTX ; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED +; RUN: opt --mtriple=amdgcn-amd-amdhsa --data-layout=A5 -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=AMDGPU-DISABLED ; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt -openmp-opt-disable-spmdization < %s | FileCheck %s --check-prefix=NVPTX-DISABLED +; RUN: opt --mtriple=nvptx64-- -S -passes=openmp-opt-postlink < %s | FileCheck %s --check-prefix=NVPTX-DISABLED ;; void unknown(void); ;; void spmd_amenable(void) __attribute__((assume("ompx_spmd_amenable")));