diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp --- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp @@ -1049,7 +1049,8 @@ void CGOpenMPRuntimeGPU::emitKernelInit(CodeGenFunction &CGF, EntryFunctionState &EST, bool IsSPMD) { CGBuilderTy &Bld = CGF.Builder; - Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, requiresFullRuntime())); + Bld.restoreIP(OMPBuilder.createTargetInit( + Bld, IsSPMD, /*IsSPMDGuarded=*/false, requiresFullRuntime())); IsInTargetMasterThreadRegion = IsSPMD; if (!IsSPMD) emitGenericVarsProlog(CGF, EST.Loc); diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -791,7 +791,9 @@ /// \param Loc The insert and source location description. /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not. + /// \param IsSPMDGuarded Flag to indicate SPMD mode with guarded regions. /// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
- InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime); + InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, + bool IsSPMDGuarded, bool RequiresFullRuntime); /// Create a runtime call for kmpc_target_deinit /// diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -412,7 +412,7 @@ /* Int */ Int32, /* kmp_task_t */ VoidPtr) /// OpenMP Device runtime functions -__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1) +__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1, Int1) __OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1) __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr) __OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32, @@ -438,6 +438,7 @@ __OMP_RTL(__kmpc_get_shared_variables, false, Void, VoidPtrPtrPtr) __OMP_RTL(__kmpc_parallel_level, false, Int8, ) __OMP_RTL(__kmpc_is_spmd_exec_mode, false, Int8, ) +__OMP_RTL(__kmpc_is_spmd_guarded_exec_mode, false, Int8, ) __OMP_RTL(__kmpc_barrier_simple_spmd, false, Void, IdentPtr, Int32) __OMP_RTL(__kmpc_warp_active_thread_mask, false, LanemaskTy,) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -2193,13 +2193,17 @@ } OpenMPIRBuilder::InsertPointTy -OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) { +OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, + bool IsSPMDGuarded, + bool RequiresFullRuntime) { if (!updateToLocation(Loc)) return Loc.IP; Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); Value *Ident = getOrCreateIdent(SrcLocStr); ConstantInt *IsSPMDVal =
ConstantInt::getBool(Int32->getContext(), IsSPMD); + ConstantInt *IsSPMDGuardedVal = + ConstantInt::getBool(Int32->getContext(), IsSPMDGuarded); ConstantInt *UseGenericStateMachine = ConstantInt::getBool(Int32->getContext(), !IsSPMD); ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime); @@ -2208,7 +2212,8 @@ omp::RuntimeFunction::OMPRTL___kmpc_target_init); CallInst *ThreadKind = - Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal}); + Builder.CreateCall(Fn, {Ident, IsSPMDVal, IsSPMDGuardedVal, + UseGenericStateMachine, RequiresFullRuntimeVal}); Value *ExecUserCode = Builder.CreateICmpEQ( ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code"); diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -39,6 +39,7 @@ #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" using namespace llvm; @@ -503,7 +504,7 @@ /// State to track if we are in SPMD-mode, assumed or know, and why we decided /// we cannot be. If it is assumed, then RequiresFullRuntime should also be /// false. - BooleanStateWithPtrSetVector SPMDCompatibilityTracker; + BooleanStateWithPtrSetVector SPMDCompatibilityTracker; /// The __kmpc_target_init call in this kernel, if any. If we find more than /// one we abort as the kernel is malformed. @@ -2756,6 +2757,12 @@ AAKernelInfoFunction(const IRPosition &IRP, Attributor &A) : AAKernelInfo(IRP, A) {} + /// Instructions of this function that have already been guarded. + SmallPtrSet GuardedInstructions; + SmallPtrSetImpl &getGuardedInstructions() { + return GuardedInstructions; + } + /// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override { // This is a high-level transform that might change the constant arguments @@ -2849,6 +2856,29 @@ return Val; }; + Attributor::SimplifictionCallbackTy IsSPMDGuardedModeSimplifyCB = + [&](const IRPosition &IRP, const AbstractAttribute *AA, + bool &UsedAssumedInformation) -> Optional { + // IRP represents the "IsSPMDGuarded" argument of the __kmpc_target_init + // call. It is true iff the kernel is assumed SPMD compatible and some + // instructions need guarding. We will answer this one with the internal + // state. + if (!SPMDCompatibilityTracker.isValidState()) + return nullptr; + if (!SPMDCompatibilityTracker.isAtFixpoint()) { + if (AA) + A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); + UsedAssumedInformation = true; + } else { + UsedAssumedInformation = false; + } + + auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), + (SPMDCompatibilityTracker.isAssumed() && + !SPMDCompatibilityTracker.empty())); + return Val; + }; + Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB = [&](const IRPosition &IRP, const AbstractAttribute *AA, bool &UsedAssumedInformation) -> Optional { @@ -2871,9 +2901,10 @@ }; constexpr const int InitIsSPMDArgNo = 1; + constexpr const int InitIsSPMDGuardedArgNo = 2; + constexpr const int InitUseStateMachineArgNo = 3; + constexpr const int InitRequiresFullRuntimeArgNo = 4; constexpr const int DeinitIsSPMDArgNo = 1; - constexpr const int InitUseStateMachineArgNo = 2; - constexpr const int InitRequiresFullRuntimeArgNo = 3; constexpr const int DeinitRequiresFullRuntimeArgNo = 2; A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), @@ -2881,6 +2912,9 @@ A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo), IsSPMDModeSimplifyCB); + A.registerSimplificationCallback( + IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDGuardedArgNo), + IsSPMDGuardedModeSimplifyCB); A.registerSimplificationCallback(
IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), IsSPMDModeSimplifyCB); @@ -2952,6 +2986,223 @@ return false; } + auto CreateGuardedRegion = [&](Instruction *RegionStartI, + Instruction *RegionEndI) { + LoopInfo *LI = nullptr; + DominatorTree *DT = nullptr; + MemorySSAUpdater *MSU = nullptr; + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + + BasicBlock *ParentBB = RegionStartI->getParent(); + Function *Fn = ParentBB->getParent(); + Module &M = *Fn->getParent(); + + // Create all the blocks and logic. + // ParentBB: + // Tid = __kmpc_get_hardware_thread_id_in_block() + // IsSPMDGuarded = __kmpc_is_spmd_guarded_exec_mode() + // if (IsSPMDGuarded) + // goto RegionCheckTidBB + // RegionNotguardedBB: + // ... non-guarded clone of the region instructions ... + // goto RegionExitBB + // RegionCheckTidBB: + // if (Tid != 0) + // goto RegionBarrierBB + // RegionStartBB: + // ... guarded region instructions ... + // goto RegionEndBB + // RegionEndBB: + // ... store escaping values to shared memory ... + // goto RegionBarrierBB + // RegionBarrierBB: + // __kmpc_barrier_simple_spmd() + // // second barrier is omitted if lacking escaping values. + // ... load escaping values from shared memory ... + // __kmpc_barrier_simple_spmd() + // goto RegionExitBB + // RegionExitBB: + // ... PHIs merging escaping values of both paths (if any) ... + + BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(), + DT, LI, MSU, "region.guarded.end"); + BasicBlock *RegionBarrierBB = + SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI, + MSU, "region.barrier"); + BasicBlock *RegionExitBB = + SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(), + DT, LI, MSU, "region.exit"); + BasicBlock *RegionStartBB = + SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded"); + + // Create a clone that contains a non-guarded version for parallel + // execution.
+ ValueToValueMapTy VMap; + BasicBlock *RegionNotguardedBB = + CloneBasicBlock(RegionStartBB, VMap, ".not"); + RegionNotguardedBB->insertInto(Fn, RegionStartBB); + RegionNotguardedBB->getTerminator()->setSuccessor(0, RegionExitBB); + + assert(ParentBB->getUniqueSuccessor() == RegionStartBB && + "Expected a different CFG"); + + BasicBlock *RegionCheckTidBB = SplitBlock( + ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid"); + + // Register basic blocks with the Attributor. + A.registerManifestAddedBasicBlock(*RegionEndBB); + A.registerManifestAddedBasicBlock(*RegionBarrierBB); + A.registerManifestAddedBasicBlock(*RegionExitBB); + A.registerManifestAddedBasicBlock(*RegionStartBB); + A.registerManifestAddedBasicBlock(*RegionCheckTidBB); + A.registerManifestAddedBasicBlock(*RegionNotguardedBB); + + bool HasBroadcastValues = false; + // Find escaping outputs from the guarded region to outside users and + // broadcast their values to them. + for (Instruction &I : *RegionStartBB) { + SmallPtrSet OutsideUsers; + for (User *Usr : I.users()) { + Instruction &UsrI = *cast(Usr); + if (UsrI.getParent() != RegionStartBB) { + LLVM_DEBUG(dbgs() << "For I " << I << " in BB " + << I.getParent()->getName() << " found outside user UsrI " + << UsrI << " in BB " << UsrI.getParent()->getName() << "\n"); + OutsideUsers.insert(&UsrI); + } + } + + if (OutsideUsers.empty()) + continue; + + HasBroadcastValues = true; + + // Emit a global variable in shared memory to store the broadcasted + // value. + auto *SharedMem = new GlobalVariable( + M, I.getType(), /* IsConstant */ false, + GlobalValue::InternalLinkage, UndefValue::get(I.getType()), + I.getName() + ".guarded.output.alloc", nullptr, + GlobalValue::NotThreadLocal, + static_cast(AddressSpace::Shared)); + + // Emit a store instruction to update the value.
+ new StoreInst(&I, SharedMem, RegionEndBB->getTerminator()); + + LoadInst *LoadI = new LoadInst(I.getType(), SharedMem, + I.getName() + ".guarded.output.load", + RegionBarrierBB->getTerminator()); + + PHINode *PN = PHINode::Create(I.getType(), 2, ".phi.guarded", + &*RegionExitBB->getFirstInsertionPt()); + PN->addIncoming(VMap[&I], RegionNotguardedBB); + PN->addIncoming(LoadI, RegionBarrierBB); + // Replace the outside uses of the output value with the PHI. + for (Instruction *UsrI : OutsideUsers) { + LLVM_DEBUG(dbgs() << "PN " << *PN << " in UsrI " << *UsrI << " in BB " + << UsrI->getParent()->getName() << "\n"); + assert(UsrI->getParent() == RegionExitBB && + "Expected escaping users in exit region"); + UsrI->replaceUsesOfWith(&I, PN); + } + } + + auto &OMPInfoCache = static_cast(A.getInfoCache()); + + // Add check for SPMD guarded execution mode in ParentBB. + const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc(); + ParentBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(ParentBB, ParentBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(Loc); + auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc); + Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr); + FunctionCallee HardwareTidFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_get_hardware_thread_id_in_block); + FunctionCallee IsSPMDGuardedFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_is_spmd_guarded_exec_mode); + Value *Tid = + OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {}); + Value *IsSPMDGuarded = + OMPInfoCache.OMPBuilder.Builder.CreateCall(IsSPMDGuardedFn, {}); + Value *IsSPMDGuardedCheck = + OMPInfoCache.OMPBuilder.Builder.CreateIsNull(IsSPMDGuarded); + OMPInfoCache.OMPBuilder.Builder + .CreateCondBr(IsSPMDGuardedCheck, RegionNotguardedBB, + RegionCheckTidBB) + ->setDebugLoc(DL); + + // Add check for Tid in RegionCheckTidBB
+ RegionCheckTidBB->getTerminator()->eraseFromParent(); + OpenMPIRBuilder::LocationDescription LocRegionCheckTid( + InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL); + OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid); + Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid); + OMPInfoCache.OMPBuilder.Builder + .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB) + ->setDebugLoc(DL); + + // First barrier for synchronization, ensures main thread has updated + // values. + FunctionCallee BarrierFn = + OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction( + M, OMPRTL___kmpc_barrier_simple_spmd); + OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy( + RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt())); + OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid}) + ->setDebugLoc(DL); + + // Second barrier ensures workers have read broadcast values. + if (HasBroadcastValues) + CallInst::Create(BarrierFn, {Ident, Tid}, "", + RegionBarrierBB->getTerminator()) + ->setDebugLoc(DL); + }; + + // Collect runs of consecutive instructions that need guarding. + SmallVector, 4> GuardedRegions; + + for (Instruction *GuardedI : SPMDCompatibilityTracker) { + BasicBlock *BB = GuardedI->getParent(); + auto *CalleeAA = A.lookupAAFor( + IRPosition::function(*GuardedI->getFunction()), nullptr, + DepClassTy::NONE); + assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo"); + auto &CalleeAAFunction = *cast(CalleeAA); + // Continue if instruction is already guarded. + if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI)) + continue; + + Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr; + for (Instruction &I : *BB) { + // If instruction I needs to be guarded update the guarded region + // bounds.
+ if (SPMDCompatibilityTracker.contains(&I)) { + CalleeAAFunction.getGuardedInstructions().insert(&I); + if (GuardedRegionStart) + GuardedRegionEnd = &I; + else + GuardedRegionStart = GuardedRegionEnd = &I; + + continue; + } + + // Instruction I does not need guarding, store + // any region found and reset bounds. + if (GuardedRegionStart) { + GuardedRegions.push_back( + std::make_pair(GuardedRegionStart, GuardedRegionEnd)); + GuardedRegionStart = nullptr; + GuardedRegionEnd = nullptr; + } + } + } + + for (auto &GR : GuardedRegions) + CreateGuardedRegion(GR.first, GR.second); + // Adjust the global exec mode flag that tells the runtime what mode this // kernel is executed in. Function *Kernel = getAnchorScope(); @@ -2970,14 +3221,18 @@ // Next rewrite the init and deinit calls to indicate we use SPMD-mode now. const int InitIsSPMDArgNo = 1; + const int InitIsSPMDGuardedArgNo = 2; + const int InitUseStateMachineArgNo = 3; + const int InitRequiresFullRuntimeArgNo = 4; const int DeinitIsSPMDArgNo = 1; - const int InitUseStateMachineArgNo = 2; - const int InitRequiresFullRuntimeArgNo = 3; const int DeinitRequiresFullRuntimeArgNo = 2; auto &Ctx = getAnchorValue().getContext(); A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), *ConstantInt::getBool(Ctx, 1)); + A.changeUseAfterManifest( + KernelInitCB->getArgOperandUse(InitIsSPMDGuardedArgNo), + *ConstantInt::getBool(Ctx, !SPMDCompatibilityTracker.empty())); A.changeUseAfterManifest( KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *ConstantInt::getBool(Ctx, 0)); @@ -3005,7 +3260,7 @@ "Custom state machine with invalid parallel region states?"); const int InitIsSPMDArgNo = 1; - const int InitUseStateMachineArgNo = 2; + const int InitUseStateMachineArgNo = 3; // Check if the current configuration is non-SPMD and generic state machine.
// If we already have SPMD mode or a custom state machine we do not need to @@ -3283,8 +3538,21 @@ if (llvm::all_of(Objects, [](const Value *Obj) { return isa(Obj); })) return true; + // Check for AAHeapToStack moved objects to avoid guarding. + auto *HS = A.lookupAAFor( + IRPosition::function(*I.getFunction()), this, DepClassTy::OPTIONAL); + if (HS) + if (llvm::all_of(Objects, [HS](const Value *Obj) { + auto *CB = dyn_cast(Obj); + if (!CB) + return false; + return HS->isAssumedHeapToStack(*CB); + })) { + return true; + } } - // For now we give up on everything but stores. + + // Insert instruction that needs guarding. SPMDCompatibilityTracker.insert(&I); return true; }; @@ -3470,7 +3738,8 @@ // We do not look into tasks right now, just give up. SPMDCompatibilityTracker.insert(&CB); ReachedUnknownParallelRegions.insert(&CB); - break; + indicatePessimisticFixpoint(); + return; case OMPRTL___kmpc_alloc_shared: case OMPRTL___kmpc_free_shared: // Return without setting a fixpoint, to be resolved in updateImpl. @@ -3479,7 +3748,8 @@ // Unknown OpenMP runtime calls cannot be executed in SPMD-mode, // generally. SPMDCompatibilityTracker.insert(&CB); - break; + indicatePessimisticFixpoint(); + return; } // All other OpenMP runtime calls will not reach parallel regions so they // can be safely ignored for now. Since it is a known OpenMP runtime call we diff --git a/openmp/libomptarget/deviceRTLs/common/include/target.h b/openmp/libomptarget/deviceRTLs/common/include/target.h --- a/openmp/libomptarget/deviceRTLs/common/include/target.h +++ b/openmp/libomptarget/deviceRTLs/common/include/target.h @@ -72,7 +72,7 @@ /// /// \param Ident Source location identification, can be NULL.
/// -int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, +int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, bool IsSPMDGuarded, bool UseGenericStateMachine, bool RequiresFullRuntime); diff --git a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu --- a/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/omptarget.cu @@ -82,11 +82,12 @@ omptarget_nvptx_workFn = 0; } -static void __kmpc_spmd_kernel_init(bool RequiresFullRuntime) { +static void __kmpc_spmd_kernel_init(bool IsSPMDGuarded, bool RequiresFullRuntime) { PRINT0(LD_IO, "call to __kmpc_spmd_kernel_init\n"); - setExecutionParameters(Spmd, RequiresFullRuntime ? RuntimeInitialized - : RuntimeUninitialized); + setExecutionParameters(IsSPMDGuarded ? SpmdGuarded : Spmd, + RequiresFullRuntime ? RuntimeInitialized + : RuntimeUninitialized); int threadId = __kmpc_get_hardware_thread_id_in_block(); if (threadId == 0) { usedSlotIdx = __kmpc_impl_smid() % MAX_SM; @@ -160,7 +161,12 @@ // Return true if the current target region is executed in SPMD mode.
EXTERN int8_t __kmpc_is_spmd_exec_mode() { - return (execution_param & ModeMask) == Spmd; + return ((execution_param & ModeMask) == Spmd || + (execution_param & ModeMask) == SpmdGuarded); +} + +EXTERN __attribute__((used,retain,weak)) int8_t __kmpc_is_spmd_guarded_exec_mode() { + return ((execution_param & ModeMask) == SpmdGuarded); } EXTERN int8_t __kmpc_is_generic_main_thread(kmp_int32 Tid) { @@ -202,12 +208,12 @@ } EXTERN -int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, +int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, bool IsSPMDGuarded, bool UseGenericStateMachine, bool RequiresFullRuntime) { int TId = __kmpc_get_hardware_thread_id_in_block(); if (IsSPMD) - __kmpc_spmd_kernel_init(RequiresFullRuntime); + __kmpc_spmd_kernel_init(IsSPMDGuarded, RequiresFullRuntime); else __kmpc_generic_kernel_init(); diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu --- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu +++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu @@ -300,7 +300,36 @@ } if (__kmpc_is_spmd_exec_mode()) { + // Store spmd_guarded status to check after the parallel region executes. + int is_spmd_guarded = __kmpc_is_spmd_guarded_exec_mode(); + if (is_spmd_guarded) { + // No barrier is needed on entry since this will be called only from + // non-guarded SPMD execution. + + // Disable SPMD guarding for the parallel region. Only the execution-mode + // bits are replaced: the runtime-mode bits set at kernel init must be + // preserved, otherwise an uninitialized runtime would be reported as + // initialized inside the parallel region. + if (__kmpc_get_hardware_thread_id_in_block() == 0) + execution_param = (execution_param & ~ModeMask) | Spmd; + + // Barrier to ensure all threads are updated to Spmd. + __kmpc_barrier_simple_spmd(ident, 0); + } + __kmp_invoke_microtask(global_tid, 0, fn, args, nargs); + + if (is_spmd_guarded) { + // Re-enable SPMD guarding,
again preserving the runtime-mode bits. + // Barrier to ensure all threads have finished Spmd execution before + // re-enabling guarding. + __kmpc_barrier_simple_spmd(ident, 0); + if (__kmpc_get_hardware_thread_id_in_block() == 0) + execution_param = (execution_param & ~ModeMask) | SpmdGuarded; + + // Barrier to ensure all threads are updated to SpmdGuarded. + __kmpc_barrier_simple_spmd(ident, 0); + } return; } diff --git a/openmp/libomptarget/deviceRTLs/common/support.h b/openmp/libomptarget/deviceRTLs/common/support.h --- a/openmp/libomptarget/deviceRTLs/common/support.h +++ b/openmp/libomptarget/deviceRTLs/common/support.h @@ -22,13 +22,14 @@ enum ExecutionMode { Spmd = 0x00u, Generic = 0x01u, - ModeMask = 0x01u, + SpmdGuarded = 0x02u, + ModeMask = 0x03u, }; enum RuntimeMode { RuntimeInitialized = 0x00u, - RuntimeUninitialized = 0x02u, - RuntimeMask = 0x02u, + RuntimeUninitialized = 0x04u, + RuntimeMask = 0x04u, }; void setExecutionParameters(ExecutionMode EMode, RuntimeMode RMode); diff --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h --- a/openmp/libomptarget/deviceRTLs/interface.h +++ b/openmp/libomptarget/deviceRTLs/interface.h @@ -417,8 +417,9 @@ // non standard EXTERN int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD, - bool UseGenericStateMachine, - bool RequiresFullRuntime); + bool IsSPMDGuarded, + bool UseGenericStateMachine, + bool RequiresFullRuntime); EXTERN void __kmpc_target_deinit(ident_t *Ident, bool IsSPMD, bool RequiresFullRuntime); EXTERN void __kmpc_kernel_prepare_parallel(void *WorkFn); @@ -449,6 +450,8 @@ // SPMD execution mode interrogation function. EXTERN int8_t __kmpc_is_spmd_exec_mode(); +EXTERN int8_t __kmpc_is_spmd_guarded_exec_mode(); + /// Return true if the hardware thread id \p Tid represents the OpenMP main /// thread in generic mode outside of a parallel region. EXTERN int8_t __kmpc_is_generic_main_thread(kmp_int32 Tid);