diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -497,6 +497,16 @@
   /// one we abort as the kernel is malformed.
   CallBase *KernelDeinitCB = nullptr;
 
+  /// Call sites mapped to the constant they fold to; nullptr if not foldable.
+  SmallDenseMap<CallBase *, Constant *> FoldableFunctions;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// Kernel entries that can reach the associated function. Inserting an
+  /// entry keeps the state valid; losing track of a caller invalidates it.
+  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
   /// Abstract State interface
   ///{
 
@@ -517,6 +527,7 @@
     IsAtFixpoint = true;
     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+    ReachingKernelEntries.indicatePessimisticFixpoint();
     return ChangeStatus::CHANGED;
   }
 
@@ -535,11 +546,15 @@
       return false;
     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
       return false;
     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
       return false;
+    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+      return false;
+    if (FoldableFunctions != RHS.FoldableFunctions)
+      return false;
     return true;
   }
 
   /// Return empty set as the best state of potential values.
   static KernelInfoState getBestState() { return KernelInfoState(true); }
 
@@ -566,6 +581,7 @@
     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+    // ReachingKernelEntries flows from callers to callees; do not merge it.
     return *this;
   }
 
@@ -2725,6 +2741,10 @@
     if (!OMPInfoCache.Kernels.count(Fn))
       return;
 
+    // A kernel entry is reached by itself; remember that it is an entry.
+    ReachingKernelEntries.insert(Fn);
+    IsKernelEntry = true;
+
     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
@@ -2826,21 +2846,34 @@
   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
   /// finished now.
   ChangeStatus manifest(Attributor &A) override {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    // Fold the calls whose constant result was determined during the update.
+    // This is only sound while the set of reaching kernels is complete.
+    bool CanFold = ReachingKernelEntries.isValidState();
+    for (auto &P : FoldableFunctions) {
+      if (!CanFold || !P.second)
+        continue;
+      A.changeValueAfterManifest(*P.first, *P.second);
+      A.deleteAfterManifest(*P.first);
+      Change = ChangeStatus::CHANGED;
+    }
+
     // If we are not looking at a kernel with __kmpc_target_init and
     // __kmpc_target_deinit call we cannot actually manifest the information.
     if (!KernelInitCB || !KernelDeinitCB)
-      return ChangeStatus::UNCHANGED;
+      return Change;
 
     // Known SPMD-mode kernels need no manifest changes.
     if (SPMDCompatibilityTracker.isKnown())
-      return ChangeStatus::UNCHANGED;
+      return Change;
 
     // If we can we change the execution mode to SPMD-mode otherwise we build a
     // custom state machine.
     if (!changeToSPMDMode(A))
-      buildCustomStateMachine(A);
+      return Change | buildCustomStateMachine(A);
 
     return ChangeStatus::CHANGED;
   }
 
   bool changeToSPMDMode(Attributor &A) {
@@ -3203,6 +3236,12 @@
     if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this))
       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
 
+    updateReachingKernelEntries(A);
+
+    // Folding decisions are only sound if all reaching kernels are known.
+    if (ReachingKernelEntries.isValidState())
+      updateSPMDFolding(A);
+
     // Callback to check a call instruction.
auto CheckCallInst = [&](Instruction &I) { auto &CB = cast(I); @@ -3210,6 +3249,19 @@ *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (CBAA.getState().isValidState()) getState() ^= CBAA.getState(); + + Function *Callee = CB.getCalledFunction(); + if (Callee) { + // We need to propagate information to the callee, but since the + // construction of AA always starts with kernel entries, we have to + // create AAKernelInfoFunction for all called functions. However, here + // the caller doesn't depend on the callee. + // TODO: We might want to change the dependence here later if we need + // information from callee to caller. + A.getOrCreateAAFor(IRPosition::function(*Callee), this, + DepClassTy::NONE); + } + return true; }; @@ -3219,6 +3271,77 @@ return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + +private: + /// Update info regarding reaching kernels. + void updateReachingKernelEntries(Attributor &A) { + if (!IsKernelEntry) { + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor(IRPosition::function(*Caller)); + if (CAA.isValidState()) { + ReachingKernelEntries ^= CAA.ReachingKernelEntries; + return true; + } + + // We lost track of the caller of the associated function, any kernel + // could reach now. + ReachingKernelEntries.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ReachingKernelEntries.indicatePessimisticFixpoint(); + } + } + + /// Update information regarding folding SPMD mode function calls. 
+  void updateSPMDFolding(Attributor &A) {
+    unsigned SPMDCount = 0;
+
+    for (Kernel K : ReachingKernelEntries) {
+      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+                                          DepClassTy::REQUIRED);
+      assert(AA.isValidState() && "AA should be valid here");
+      // Known implies assumed, so checking the assumed state covers both.
+      if (AA.SPMDCompatibilityTracker.isAssumed())
+        ++SPMDCount;
+    }
+
+    // Default to "not foldable": the reaching kernels are a mixture of SPMD
+    // and non-SPMD kernels, or no reaching kernel is known (yet).
+    Constant *C = nullptr;
+
+    auto &Ctx = getAnchorValue().getContext();
+    size_t NumKernels = ReachingKernelEntries.size();
+    if (NumKernels != 0 && SPMDCount == 0) {
+      // All reaching kernels are in non-SPMD mode. Update all function
+      // calls to __kmpc_is_spmd_exec_mode to 0.
+      C = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
+    } else if (NumKernels != 0 && SPMDCount == NumKernels) {
+      // All reaching kernels are in SPMD mode. Update all function calls to
+      // __kmpc_is_spmd_exec_mode to 1.
+      C = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+    }
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    OMPInformationCache::RuntimeFunctionInfo &IsSPMDExecModeRFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
+
+    for (auto &P : FoldableFunctions) {
+      CallBase *CB = P.first;
+      if (CB->getCalledFunction() == IsSPMDExecModeRFI.Declaration)
+        P.second = C;
+    }
+  }
 };
 
 /// The call site kernel info abstract attribute, basically, what can we say