diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -513,6 +513,9 @@ /// State to track what kernel entries can reach the associated function. BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries; + /// State to track what parallel levels the associated function can be at. + BooleanStateWithSetVector<uint8_t> ParallelLevels; + /// Abstract State interface ///{ @@ -2697,6 +2700,8 @@ // Add itself to the reaching kernel and set IsKernelEntry. ReachingKernelEntries.insert(Fn); IsKernelEntry = true; + // Kernel entry is at level 0, which means not in parallel region. + ParallelLevels.insert(0); OMPInformationCache::RuntimeFunctionInfo &InitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; @@ -3210,8 +3215,10 @@ CheckRWInst, *this, UsedAssumedInformationInCheckRWInst)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); - if (!IsKernelEntry) + if (!IsKernelEntry) { updateReachingKernelEntries(A); + updateParallelLevels(A); + } // Callback to check a call instruction. bool AllSPMDStatesWereFixed = true; @@ -3267,6 +3274,46 @@ AllCallSitesKnown)) ReachingKernelEntries.indicatePessimisticFixpoint(); } + + /// Update info regarding parallel levels. 
+ void updateParallelLevels(Attributor &A) { + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51]; + + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); + if (CAA.ParallelLevels.isValidState()) { + if (Caller != Parallel51RFI.Declaration) + ParallelLevels ^= CAA.ParallelLevels; + else + for (uint8_t L : CAA.ParallelLevels) + ParallelLevels.insert(L + 1); + + if (ParallelLevels.size() > 1) + ParallelLevels.indicatePessimisticFixpoint(); + + return true; + } + + // We lost track of the caller of the associated function, any kernel + // could reach now. + ParallelLevels.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ParallelLevels.indicatePessimisticFixpoint(); + } }; /// The call site kernel info abstract attribute, basically, what can we say @@ -3499,6 +3546,9 @@ case OMPRTL___kmpc_is_spmd_exec_mode: Changed |= foldIsSPMDExecMode(A); break; + case OMPRTL___kmpc_parallel_level: + Changed |= foldParallelLevel(A); + break; default: llvm_unreachable("Unhandled OpenMP runtime function!"); } @@ -3589,6 +3639,34 @@ : ChangeStatus::CHANGED; } + /// Fold __kmpc_parallel_level into a constant if possible. 
+ ChangeStatus foldParallelLevel(Attributor &A) { + Optional<Value *> SimplifiedValueBefore = SimplifiedValue; + + auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED); + + if (!CallerKernelInfoAA.ParallelLevels.isValidState()) + return indicatePessimisticFixpoint(); + + // If we cannot deduct parallel region of the caller, we set SimplifiedValue + // to nullptr to indicate the associated function call cannot be folded for + // now. + if (CallerKernelInfoAA.ParallelLevels.empty()) + SimplifiedValue = nullptr; + else { + assert(CallerKernelInfoAA.ParallelLevels.size() == 1 && + "ParallelLevels has more than one elements but still in a valid " + "state."); + const uint8_t Level = CallerKernelInfoAA.ParallelLevels[0]; + auto &Ctx = getAnchorValue().getContext(); + SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), Level); + } + + return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED + : ChangeStatus::CHANGED; + } + /// An optional value the associated value is assumed to fold to. That is, we /// assume the associated value (which is a call) can be replaced by this /// simplified value. @@ -3626,6 +3704,19 @@ /* UpdateAfterInit */ false); return false; }); + + auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level]; + ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) { + CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI); + if (!CI) + return false; + A.getOrCreateAAFor<AAFoldRuntimeCall>( + IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr, + DepClassTy::NONE, /* ForceUpdate */ false, + /* UpdateAfterInit */ false); + + return false; + }); } // Create CallSite AA for all Getters. 
diff --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll @@ -0,0 +1,62 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals +; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s +target triple = "nvptx64" + +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@no_parallel_exec_mode = weak constant i8 0 +@G = external global i8 +@llvm.compiler.used = appending global [1 x i8*] [i8* @no_parallel_exec_mode], section "llvm.metadata" + +;. +; CHECK: @[[NO_PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8 +; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [1 x i8*] [i8* @no_parallel_exec_mode], section "llvm.metadata" +;. +define weak void @no_parallel() { +; CHECK-LABEL: define {{[^@]+}}@no_parallel() { +; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) +; CHECK-NEXT: call void @foo() +; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) +; CHECK-NEXT: ret void +; + %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false) + call void @foo() + call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false) + ret void +} + +define internal void @foo() { +; CHECK-LABEL: define {{[^@]+}}@foo() { +; CHECK-NEXT: call void @bar() +; CHECK-NEXT: ret void +; + call void @bar() + ret void +} + +define internal void @bar() { +; CHECK-LABEL: define {{[^@]+}}@bar() { +; CHECK-NEXT: store i8 0, i8* @G, align 1 +; CHECK-NEXT: ret void +; + %level = call i8 @__kmpc_parallel_level() + store i8 %level, i8* @G + ret void +} + +declare i8 @__kmpc_parallel_level() +declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 
zeroext) #1 +declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1 + +!llvm.module.flags = !{!0, !1} +!nvvm.annotations = !{!2} + +!0 = !{i32 7, !"openmp", i32 50} +!1 = !{i32 7, !"openmp-device", i32 50} +!2 = !{void ()* @no_parallel, !"kernel", i32 1} +;. +; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50} +; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50} +; CHECK: [[META2:![0-9]+]] = !{void ()* @no_parallel, !"kernel", i32 1} +;.