diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -519,6 +519,9 @@ /// State to track what kernel entries can reach the associated function. BooleanStateWithPtrSetVector ReachingKernelEntries; + /// State to track what parallel levels the associated function can be at. + BooleanStateWithSetVector ParallelLevels; + /// Abstract State interface ///{ @@ -2832,8 +2835,19 @@ // Check if we know we are in SPMD-mode already. ConstantInt *IsSPMDArg = dyn_cast(KernelInitCB->getArgOperand(InitIsSPMDArgNo)); - if (IsSPMDArg && !IsSPMDArg->isZero()) + if (IsSPMDArg && !IsSPMDArg->isZero()) { SPMDCompatibilityTracker.indicateOptimisticFixpoint(); + // Kernel entry is at level 1 if in SPMD mode. + // NOTE: This is quite dangerous because we assume there is no direct or + // indirect function call to `__kmpc_parallel_level` before we update the + // parallel level in `__kmpc_spmd_kernel_init`. However, we have to do it + // in this way to make the fold right. Alternatively, we could only fold + // to 0 if we don't use any assumed information. + ParallelLevels.insert(1); + } else { + // Kernel entry is at level 0 if in non-SPMD mode. + ParallelLevels.insert(0); + } } /// Modify the IR based on the KernelInfoState as the fixpoint iteration is @@ -3231,8 +3245,10 @@ CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - if (!IsKernelEntry) + if (!IsKernelEntry) { updateReachingKernelEntries(A); + updateParallelLevels(A); + } // Callback to check a call instruction. bool AllSPMDStatesWereFixed = true; @@ -3288,6 +3304,49 @@ AllCallSitesKnown)) ReachingKernelEntries.indicatePessimisticFixpoint(); } + + /// Update info regarding parallel levels. 
+ void updateParallelLevels(Attributor &A) { + auto &OMPInfoCache = static_cast(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; + + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor(IRPosition::function(*Caller)); + if (CAA.ParallelLevels.isValidState()) { + // Any function that is called by `__kmpc_parallel_51` will not be + // folded as the parallel level in the function is updated. In order to + // get it right, all the analysis would depend on the implementation. That + // said, if the implementation changes in the future, the analysis + // could be wrong. As a consequence, we are just conservative here. + if (Caller == Parallel51RFI.Declaration) { + ParallelLevels.indicatePessimisticFixpoint(); + return true; + } + + ParallelLevels ^= CAA.ParallelLevels; + + return true; + } + + // We lost track of the caller of the associated function, so any + // parallel level is possible now. + ParallelLevels.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown = true; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ParallelLevels.indicatePessimisticFixpoint(); + } }; /// The call site kernel info abstract attribute, basically, what can we say @@ -3523,6 +3582,9 @@ case OMPRTL___kmpc_is_generic_main_thread_id: Changed |= foldIsGenericMainThread(A); break; + case OMPRTL___kmpc_parallel_level: + Changed |= foldParallelLevel(A); + break; default: llvm_unreachable("Unhandled OpenMP runtime function!"); } @@ -3639,6 +3701,57 @@ : ChangeStatus::CHANGED; } + /// Fold __kmpc_parallel_level into a constant if possible. 
+ ChangeStatus foldParallelLevel(Attributor &A) { + Optional SimplifiedValueBefore = SimplifiedValue; + + auto &CallerKernelInfoAA = A.getAAFor( + *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); + + if (!CallerKernelInfoAA.ParallelLevels.isValidState()) + return indicatePessimisticFixpoint(); + + if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState()) + return indicatePessimisticFixpoint(); + + // If the parallel region of the caller can be at multiple levels, we cannot + // fold it. + if (CallerKernelInfoAA.ParallelLevels.size() > 1) + return indicatePessimisticFixpoint(); + + unsigned SPMDCount = 0; + for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) { + auto &AA = A.getAAFor(*this, IRPosition::function(*K), + DepClassTy::REQUIRED); + if (!AA.SPMDCompatibilityTracker.isValidState()) + return indicatePessimisticFixpoint(); + + if (AA.SPMDCompatibilityTracker.isAssumed()) + ++SPMDCount; + } + + // If we cannot deduce the parallel level of the caller, we set SimplifiedValue + // to nullptr to indicate the associated function call cannot be folded + // **for now**. + if (CallerKernelInfoAA.ParallelLevels.empty()) + assert(!SimplifiedValue.hasValue() && + "SimplifiedValue should keep none at this point"); + else { + const uint8_t Level = CallerKernelInfoAA.ParallelLevels[0]; + // If the caller can be reached by an SPMD kernel entry, the parallel level + // is 1. As a result, if the detected level is not 1, the caller can be at + // multiple levels, and we cannot fold it. + if (Level != 1 && SPMDCount) + return indicatePessimisticFixpoint(); + + auto &Ctx = getAnchorValue().getContext(); + SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), Level); + } + + return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + /// An optional value the associated value is assumed to fold to. That is, we /// assume the associated value (which is a call) can be replaced by this /// simplified value. 
@@ -3689,6 +3802,19 @@ /* UpdateAfterInit */ false); return false; }); + + auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level]; + ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI); + if (!CI) + return false; + A.getOrCreateAAFor( + IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, + DepClassTy::NONE, /* ForceUpdate */ false, + /* UpdateAfterInit */ false); + + return false; + }); } // Create CallSite AA for all Getters. diff --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll @@ -0,0 +1,135 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals +; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s +target triple = "nvptx64" + +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@no_spmd_exec_mode = weak constant i8 1 +@spmd_exec_mode = weak constant i8 0 +@parallel_exec_mode = weak constant i8 0 +@G = external global i8 +@llvm.compiler.used = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata" + +;. +; CHECK: @[[NO_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1 +; CHECK: @[[SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8 +; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata" +;. 
+define weak void @none_spmd() { +; CHECK-LABEL: define {{[^@]+}}@none_spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) +; CHECK-NEXT: call void @none_spmd_helper() +; CHECK-NEXT: call void @mixed_helper() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) + call void @none_spmd_helper() + call void @mixed_helper() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) + ret void +} + +define weak void @spmd() { +; CHECK-LABEL: define {{[^@]+}}@spmd() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) +; CHECK-NEXT: call void @spmd_helper() +; CHECK-NEXT: call void @mixed_helper() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) + call void @spmd_helper() + call void @mixed_helper() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) + ret void +} + +define weak void @parallel() { +; CHECK-LABEL: define {{[^@]+}}@parallel() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* align 536870912 null, i1 true, i1 false, i1 false) +; CHECK-NEXT: call void @spmd_helper() +; CHECK-NEXT: call void @__kmpc_parallel_51(%struct.ident_t* noalias noundef align 536870912 null, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 0, i8* noalias noundef align 536870912 null, i8* noalias noundef align 536870912 null, i8** noalias noundef align 536870912 null, i64 noundef 0) +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false) + call 
void @spmd_helper() + call void @__kmpc_parallel_51(%struct.ident_t* null, i32 0, i32 0, i32 0, i32 0, i8* null, i8* null, i8** null, i64 0) + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false) + ret void +} + +define internal void @mixed_helper() { +; CHECK-LABEL: define {{[^@]+}}@mixed_helper() { +; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level() +; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1 +; CHECK-NEXT: ret void +; + %level = call i8 @__kmpc_parallel_level() + store i8 %level, i8* @G + ret void +} + +define internal void @none_spmd_helper() { +; CHECK-LABEL: define {{[^@]+}}@none_spmd_helper() { +; CHECK-NEXT: store i8 0, i8* @G, align 1 +; CHECK-NEXT: ret void +; + %level = call i8 @__kmpc_parallel_level() + store i8 %level, i8* @G + ret void +} + +define internal void @spmd_helper() { +; CHECK-LABEL: define {{[^@]+}}@spmd_helper() { +; CHECK-NEXT: store i8 1, i8* @G, align 1 +; CHECK-NEXT: ret void +; + %level = call i8 @__kmpc_parallel_level() + store i8 %level, i8* @G + ret void +} + +define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) { +; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51 +; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 536870912 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 536870912 [[TMP7:%.*]], i64 [[TMP8:%.*]]) { +; CHECK-NEXT: call void @parallel_helper() +; CHECK-NEXT: ret void +; + call void @parallel_helper() + ret void +} + +define internal void @parallel_helper() { +; CHECK-LABEL: define {{[^@]+}}@parallel_helper() { +; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level() +; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1 +; CHECK-NEXT: ret void +; + %level = call i8 @__kmpc_parallel_level() + 
store i8 %level, i8* @G + ret void +} + +declare i8 @__kmpc_parallel_level() +declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1 +declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1 + +!llvm.module.flags = !{!0, !1} +!nvvm.annotations = !{!2, !3, !4} + +!0 = !{i32 7, !"openmp", i32 50} +!1 = !{i32 7, !"openmp-device", i32 50} +!2 = !{void ()* @none_spmd, !"kernel", i32 1} +!3 = !{void ()* @spmd, !"kernel", i32 1} +!4 = !{void ()* @parallel, !"kernel", i32 1} +;. +; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50} +; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50} +; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1} +; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1} +; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1} +;. diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu --- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu @@ -282,10 +282,10 @@ // parallel interface //////////////////////////////////////////////////////////////////////////////// -EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid, - kmp_int32 if_expr, kmp_int32 num_threads, - int proc_bind, void *fn, void *wrapper_fn, - void **args, size_t nargs) { +EXTERN __attribute__((noinline)) void +__kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid, kmp_int32 if_expr, + kmp_int32 num_threads, int proc_bind, void *fn, + void *wrapper_fn, void **args, size_t nargs) { // Handle the serialized case first, same for SPMD/non-SPMD except that in // SPMD mode we already incremented the parallel level counter, account for diff --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h --- a/openmp/libomptarget/deviceRTLs/interface.h +++ 
b/openmp/libomptarget/deviceRTLs/interface.h @@ -441,10 +441,10 @@ /// \param wrapper_fn The worker wrapper function of fn. /// \param args The pointer array of arguments to fn. /// \param nargs The number of arguments to fn. -EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid, - kmp_int32 if_expr, kmp_int32 num_threads, - int proc_bind, void *fn, void *wrapper_fn, - void **args, size_t nargs); +EXTERN __attribute__((noinline)) void +__kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid, kmp_int32 if_expr, + kmp_int32 num_threads, int proc_bind, void *fn, + void *wrapper_fn, void **args, size_t nargs); // SPMD execution mode interrogation function. EXTERN int8_t __kmpc_is_spmd_exec_mode();