diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -486,7 +486,8 @@
   BooleanStateWithPtrSetVector ReachedUnknownParallelRegions;
 
   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
-  /// we cannot be.
+  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
+  /// false.
   BooleanStateWithPtrSetVector SPMDCompatibilityTracker;
 
   /// The __kmpc_target_init call in this kernel, if any. If we find more than
@@ -2803,9 +2804,32 @@
       return Val;
     };
 
+    Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
+        [&](const IRPosition &IRP, const AbstractAttribute *AA,
+            bool &UsedAssumedInformation) -> Optional<Value *> {
+      // IRP represents the "RequiresFullRuntime" argument of an
+      // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
+      // one with the internal state of the SPMDCompatibilityTracker, so if
+      // generic then true, if SPMD then false.
+      if (!isValidState())
+        return nullptr;
+      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+        if (AA)
+          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
+        UsedAssumedInformation = true;
+      } else {
+        UsedAssumedInformation = false;
+      }
+      auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
+                                       !SPMDCompatibilityTracker.isAssumed());
+      return Val;
+    };
+
     constexpr const int InitIsSPMDArgNo = 1;
     constexpr const int DeinitIsSPMDArgNo = 1;
     constexpr const int InitUseStateMachineArgNo = 2;
+    constexpr const int InitRequiresFullRuntimeArgNo = 3;
+    constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
     A.registerSimplificationCallback(
         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
         StateMachineSimplifyCB);
@@ -2815,6 +2839,14 @@
     A.registerSimplificationCallback(
         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
         IsSPMDModeSimplifyCB);
+    A.registerSimplificationCallback(
+        IRPosition::callsite_argument(*KernelInitCB,
+                                      InitRequiresFullRuntimeArgNo),
+        IsGenericModeSimplifyCB);
+    A.registerSimplificationCallback(
+        IRPosition::callsite_argument(*KernelDeinitCB,
+                                      DeinitRequiresFullRuntimeArgNo),
+        IsGenericModeSimplifyCB);
 
     // Check if we know we are in SPMD-mode already.
     ConstantInt *IsSPMDArg =
@@ -2883,6 +2915,8 @@
     const int InitIsSPMDArgNo = 1;
     const int DeinitIsSPMDArgNo = 1;
     const int InitUseStateMachineArgNo = 2;
+    const int InitRequiresFullRuntimeArgNo = 3;
+    const int DeinitRequiresFullRuntimeArgNo = 2;
 
     auto &Ctx = getAnchorValue().getContext();
     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
@@ -2893,6 +2927,15 @@
     A.changeUseAfterManifest(
         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
         *ConstantInt::getBool(Ctx, 1));
+    // SPMD mode never needs the full runtime, so fold the RequiresFullRuntime
+    // argument of both the init and the deinit call to false.
+    A.changeUseAfterManifest(
+        KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
+        *ConstantInt::getBool(Ctx, 0));
+    A.changeUseAfterManifest(
+        KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
+        *ConstantInt::getBool(Ctx, 0));
+
     ++NumOpenMPTargetRegionKernelsSPMD;
 
     auto Remark = [&](OptimizationRemark OR) {