diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -486,7 +486,8 @@ BooleanStateWithPtrSetVector ReachedUnknownParallelRegions; /// State to track if we are in SPMD-mode, assumed or know, and why we decided - /// we cannot be. + /// we cannot be. If it is assumed, then RequiresFullRuntime should also be + /// false. BooleanStateWithPtrSetVector SPMDCompatibilityTracker; /// The __kmpc_target_init call in this kernel, if any. If we find more than @@ -2806,6 +2807,8 @@ constexpr const int InitIsSPMDArgNo = 1; constexpr const int DeinitIsSPMDArgNo = 1; constexpr const int InitUseStateMachineArgNo = 2; + constexpr const int InitRequiresFullRuntimeArgNo = 3; + constexpr const int DeinitRequiresFullRuntimeArgNo = 2; A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo), StateMachineSimplifyCB); @@ -2815,6 +2818,14 @@ A.registerSimplificationCallback( IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo), IsSPMDModeSimplifyCB); + A.registerSimplificationCallback( + IRPosition::callsite_argument(*KernelInitCB, + InitRequiresFullRuntimeArgNo), + IsSPMDModeSimplifyCB); + A.registerSimplificationCallback( + IRPosition::callsite_argument(*KernelDeinitCB, + DeinitRequiresFullRuntimeArgNo), + IsSPMDModeSimplifyCB); // Check if we know we are in SPMD-mode already. ConstantInt *IsSPMDArg = @@ -2883,6 +2894,8 @@ const int InitIsSPMDArgNo = 1; const int DeinitIsSPMDArgNo = 1; const int InitUseStateMachineArgNo = 2; + const int InitRequiresFullRuntimeArgNo = 3; + const int DeinitRequiresFullRuntimeArgNo = 2; auto &Ctx = getAnchorValue().getContext(); A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo), @@ -2893,6 +2906,13 @@ A.changeUseAfterManifest( KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo), *ConstantInt::getBool(Ctx, 1)); + A.changeUseAfterManifest( + KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo), + *ConstantInt::getBool(Ctx, 0)); + A.changeUseAfterManifest( + KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo), + *ConstantInt::getBool(Ctx, 0)); + ++NumOpenMPTargetRegionKernelsSPMD; auto Remark = [&](OptimizationRemark OR) {