diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -519,6 +519,11 @@
   /// State to track what kernel entries can reach the associated function.
   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
 
+  /// State to indicate if we can track parallel level of the associated
+  /// function. We will give up tracking if we encounter an unknown caller or
+  /// the caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
   /// Abstract State interface
   ///{
 
@@ -3329,8 +3334,10 @@
             CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
 
-    if (!IsKernelEntry)
+    if (!IsKernelEntry) {
       updateReachingKernelEntries(A);
+      updateParallelLevels(A);
+    }
 
     // Callback to check a call instruction.
     bool AllSPMDStatesWereFixed = true;
@@ -3386,6 +3393,49 @@
                                 AllCallSitesKnown))
       ReachingKernelEntries.indicatePessimisticFixpoint();
   }
+
+  /// Update info regarding parallel levels.
+  void updateParallelLevels(Attributor &A) {
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
+
+    auto PredCallSite = [&](AbstractCallSite ACS) {
+      Function *Caller = ACS.getInstruction()->getFunction();
+
+      assert(Caller && "Caller is nullptr");
+
+      auto &CAA =
+          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+      if (CAA.ParallelLevels.isValidState()) {
+        // Any function that is called by `__kmpc_parallel_51` will not be
+        // folded as the parallel level in the function is updated. In order to
+        // get it right, the analysis would have to depend on the implementation
+        // of the runtime; if that implementation changes in the future, the
+        // analysis could be wrong. As a consequence, we are conservative here.
+        if (Caller == Parallel51RFI.Declaration) {
+          ParallelLevels.indicatePessimisticFixpoint();
+          return true;
+        }
+
+        ParallelLevels ^= CAA.ParallelLevels;
+
+        return true;
+      }
+
+      // We lost track of the caller of the associated function, so its
+      // parallel level can no longer be tracked.
+      ParallelLevels.indicatePessimisticFixpoint();
+
+      return true;
+    };
+
+    bool AllCallSitesKnown = true;
+    if (!A.checkForAllCallSites(PredCallSite, *this,
+                                true /* RequireAllCallSites */,
+                                AllCallSitesKnown))
+      ParallelLevels.indicatePessimisticFixpoint();
+  }
 };
 
 /// The call site kernel info abstract attribute, basically, what can we say
@@ -3668,6 +3718,9 @@
     case OMPRTL___kmpc_is_generic_main_thread_id:
       Changed |= foldIsGenericMainThread(A);
       break;
+    case OMPRTL___kmpc_parallel_level:
+      Changed |= foldParallelLevel(A);
+      break;
     default:
       llvm_unreachable("Unhandled OpenMP runtime function!");
     }
@@ -3782,6 +3835,68 @@
                                                     : ChangeStatus::CHANGED;
   }
 
+  /// Fold __kmpc_parallel_level into a constant if possible.
+  ChangeStatus foldParallelLevel(Attributor &A) {
+    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
+
+    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
+
+    if (!CallerKernelInfoAA.ParallelLevels.isValidState())
+      return indicatePessimisticFixpoint();
+
+    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+      return indicatePessimisticFixpoint();
+
+    if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
+      assert(!SimplifiedValue.hasValue() &&
+             "SimplifiedValue should keep none at this point");
+      return ChangeStatus::UNCHANGED;
+    }
+
+    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
+    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
+    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
+      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+                                          DepClassTy::REQUIRED);
+      if (!AA.SPMDCompatibilityTracker.isValidState())
+        return indicatePessimisticFixpoint();
+
+      if (AA.SPMDCompatibilityTracker.isAssumed()) {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownSPMDCount;
+        else
+          ++AssumedSPMDCount;
+      } else {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownNonSPMDCount;
+        else
+          ++AssumedNonSPMDCount;
+      }
+    }
+
+    if ((AssumedSPMDCount + KnownSPMDCount) &&
+        (AssumedNonSPMDCount + KnownNonSPMDCount))
+      return indicatePessimisticFixpoint();
+
+    auto &Ctx = getAnchorValue().getContext();
+    // If the caller can only be reached by SPMD kernel entries, the parallel
+    // level is 1. Similarly, if the caller can only be reached by non-SPMD
+    // kernel entries, it is 0.
+    if (AssumedSPMDCount || KnownSPMDCount) {
+      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
+             "Expected only SPMD kernels!");
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+    } else {
+      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
+             "Expected only non-SPMD kernels!");
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
+    }
+
+    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
+                                                    : ChangeStatus::CHANGED;
+  }
+
   /// An optional value the associated value is assumed to fold to. That is, we
   /// assume the associated value (which is a call) can be replaced by this
   /// simplified value.
@@ -3832,6 +3947,19 @@
                                             /* UpdateAfterInit */ false);
       return false;
     });
+
+    auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
+    ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
+      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
+      if (!CI)
+        return false;
+      A.getOrCreateAAFor<AAFoldRuntimeCall>(
+          IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
+          DepClassTy::NONE, /* ForceUpdate */ false,
+          /* UpdateAfterInit */ false);
+
+      return false;
+    });
   }
 
   // Create CallSite AA for all Getters.
diff --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
@@ -0,0 +1,150 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
+; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
+target triple = "nvptx64"
+
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+@no_spmd_exec_mode = weak constant i8 1
+@spmd_exec_mode = weak constant i8 0
+@parallel_exec_mode = weak constant i8 0
+@G = external global i8
+@llvm.compiler.used = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+
+;.
+; CHECK: @[[NO_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1
+; CHECK: @[[SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+;.
+define weak void @none_spmd() { ; kernel (listed in !nvvm.annotations); passes `i1 false` as the first flag of @__kmpc_target_init
+; CHECK-LABEL: define {{[^@]+}}@none_spmd() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+; CHECK-NEXT:    call void @none_spmd_helper()
+; CHECK-NEXT:    call void @mixed_helper()
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+  call void @none_spmd_helper()
+  call void @mixed_helper()
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+  ret void
+}
+
+define weak void @spmd() { ; kernel (listed in !nvvm.annotations); passes `i1 true` as the first flag of @__kmpc_target_init
+; CHECK-LABEL: define {{[^@]+}}@spmd() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+; CHECK-NEXT:    call void @spmd_helper()
+; CHECK-NEXT:    call void @mixed_helper()
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+  call void @spmd_helper()
+  call void @mixed_helper()
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+  ret void
+}
+
+define weak void @parallel() { ; kernel (listed in !nvvm.annotations); like @spmd but also calls @__kmpc_parallel_51
+; CHECK-LABEL: define {{[^@]+}}@parallel() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* align 536870912 null, i1 true, i1 false, i1 false)
+; CHECK-NEXT:    call void @spmd_helper()
+; CHECK-NEXT:    call void @__kmpc_parallel_51(%struct.ident_t* noalias noundef align 536870912 null, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 0, i8* noalias noundef align 536870912 null, i8* noalias noundef align 536870912 null, i8** noalias noundef align 536870912 null, i64 noundef 0)
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+  call void @spmd_helper()
+  call void @__kmpc_parallel_51(%struct.ident_t* null, i32 0, i32 0, i32 0, i32 0, i8* null, i8* null, i8** null, i64 0)
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+  ret void
+}
+
+define internal void @mixed_helper() { ; called from both @none_spmd and @spmd; the assertions keep the @__kmpc_parallel_level call
+; CHECK-LABEL: define {{[^@]+}}@mixed_helper() {
+; CHECK-NEXT:    [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+define internal void @none_spmd_helper() { ; called only from @none_spmd; the assertions keep the @__kmpc_parallel_level call
+; CHECK-LABEL: define {{[^@]+}}@none_spmd_helper() {
+; CHECK-NEXT:    [[LEVEL12:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i8 [[LEVEL12]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
+; CHECK:       t:
+; CHECK-NEXT:    call void @foo()
+; CHECK-NEXT:    ret void
+; CHECK:       f:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    ret void
+;
+  %level12 = call i8 @__kmpc_parallel_level()
+  %c = icmp eq i8 %level12, 0
+  br i1 %c, label %t, label %f
+t:
+  call void @foo()
+  ret void
+f:
+  call void @bar()
+  ret void
+}
+
+define internal void @spmd_helper() { ; called from @spmd and @parallel; the assertions expect the level to be replaced by the constant 1
+; CHECK-LABEL: define {{[^@]+}}@spmd_helper() {
+; CHECK-NEXT:    store i8 1, i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) { ; defined locally in this test; its body only calls @parallel_helper
+; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51
+; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 536870912 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 536870912 [[TMP7:%.*]], i64 [[TMP8:%.*]]) {
+; CHECK-NEXT:    call void @parallel_helper()
+; CHECK-NEXT:    ret void
+;
+  call void @parallel_helper()
+  ret void
+}
+
+define internal void @parallel_helper() { ; called only from @__kmpc_parallel_51; the assertions keep the @__kmpc_parallel_level call
+; CHECK-LABEL: define {{[^@]+}}@parallel_helper() {
+; CHECK-NEXT:    [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+declare void @foo()
+declare void @bar()
+declare i8 @__kmpc_parallel_level()
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
+declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
+
+!llvm.module.flags = !{!0, !1}
+!nvvm.annotations = !{!2, !3, !4}
+
+!0 = !{i32 7, !"openmp", i32 50}
+!1 = !{i32 7, !"openmp-device", i32 50}
+!2 = !{void ()* @none_spmd, !"kernel", i32 1}
+!3 = !{void ()* @spmd, !"kernel", i32 1}
+!4 = !{void ()* @parallel, !"kernel", i32 1}
+;.
+; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
+; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
+; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1}
+; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1}
+; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1}
+;.
diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
--- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -239,7 +239,7 @@
   currTaskDescr->RestoreLoopData();
 }
 
-EXTERN uint8_t __kmpc_parallel_level() {
+NOINLINE EXTERN uint8_t __kmpc_parallel_level() { // NOTE(review): presumably NOINLINE keeps this call visible to the OpenMPOpt folding added in this patch - confirm
   return parallelLevel[GetWarpId()] & (OMP_ACTIVE_PARALLEL_LEVEL - 1);
 }
 
@@ -282,11 +282,11 @@
 // parallel interface
 ////////////////////////////////////////////////////////////////////////////////
 
-EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
-                               kmp_int32 if_expr, kmp_int32 num_threads,
-                               int proc_bind, void *fn, void *wrapper_fn,
-                               void **args, size_t nargs) {
-
+NOINLINE EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
+                                        kmp_int32 if_expr,
+                                        kmp_int32 num_threads, int proc_bind,
+                                        void *fn, void *wrapper_fn, void **args,
+                                        size_t nargs) { // NOTE(review): presumably NOINLINE so OpenMPOpt can still identify callers of this function - confirm
   // Handle the serialized case first, same for SPMD/non-SPMD except that in
   // SPMD mode we already incremented the parallel level counter, account for
   // that.
diff --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h
--- a/openmp/libomptarget/deviceRTLs/interface.h
+++ b/openmp/libomptarget/deviceRTLs/interface.h
@@ -222,7 +222,7 @@
                                       int32_t num_threads);
 EXTERN void __kmpc_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
 EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
-EXTERN uint8_t __kmpc_parallel_level();
+NOINLINE EXTERN uint8_t __kmpc_parallel_level();
 
 // proc bind
 EXTERN void __kmpc_push_proc_bind(kmp_Ident *loc, uint32_t global_tid,
@@ -441,10 +441,11 @@
 /// \param wrapper_fn The worker wrapper function of fn.
 /// \param args The pointer array of arguments to fn.
 /// \param nargs The number of arguments to fn.
-EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
-                               kmp_int32 if_expr, kmp_int32 num_threads,
-                               int proc_bind, void *fn, void *wrapper_fn,
-                               void **args, size_t nargs);
+NOINLINE EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
+                                        kmp_int32 if_expr,
+                                        kmp_int32 num_threads, int proc_bind,
+                                        void *fn, void *wrapper_fn, void **args,
+                                        size_t nargs);
 
 // SPMD execution mode interrogation function.
 EXTERN int8_t __kmpc_is_spmd_exec_mode();