diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -497,6 +497,28 @@ /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; + /// A map from a function to its constant return value. If the value is + /// nullptr, the function cannot be folded. + SmallDenseMap<CallBase *, Constant *> FoldableFunctions; + + /// Flag to indicate that we may reach a parallel region that is not tracked + /// in the ParallelRegions set above. + SmallPtrSet<CallBase *, 2> UnknownParallelRegions; + + /// Flag to indicate that we may reach a kernel entry that is not tracked. + bool MaybeReachedByUnknownKernel = false; + + /// Flag to indicate if the associated function is in SPMD mode. This variable + /// is only set to true if it is indeed in a SPMD mode, not from being + /// spmdized. + bool IsSPMDMode = false; + + /// Flag to indicate if the associated function is a kernel entry. + bool IsKernelEntry = false; + + /// Kernels that can reach the associated function. + SmallPtrSet<Kernel, 2> ReachingKernels; + /// Abstract State interface ///{ @@ -517,6 +539,8 @@ IsAtFixpoint = true; SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); + UnknownParallelRegions.insert(nullptr); + MaybeReachedByUnknownKernel = true; return ChangeStatus::CHANGED; } @@ -535,11 +559,15 @@ return false; if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions) return false; - if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions) + + if (ReachingKernels != RHS.ReachingKernels) return false; + + if (FoldableFunctions != RHS.FoldableFunctions) + return false; + return true; } - /// Return empty set as the best state of potential values. 
static KernelInfoState getBestState() { return KernelInfoState(true); } @@ -566,6 +594,9 @@ SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; + UnknownParallelRegions.insert(KIS.UnknownParallelRegions.begin(), + KIS.UnknownParallelRegions.end()); + MaybeReachedByUnknownKernel |= KIS.MaybeReachedByUnknownKernel; return *this; } @@ -2826,21 +2857,33 @@ /// Modify the IR based on the KernelInfoState as the fixpoint iteration is /// finished now. ChangeStatus manifest(Attributor &A) override { - // If we are not looking at a kernel with __kmpc_target_init and - // __kmpc_target_deinit call we cannot actually manifest the information. - if (!KernelInitCB || !KernelDeinitCB) - return ChangeStatus::UNCHANGED; + ChangeStatus Change = ChangeStatus::UNCHANGED; - // Known SPMD-mode kernels need no manifest changes. - if (SPMDCompatibilityTracker.isKnown()) - return ChangeStatus::UNCHANGED; + // This is a somewhat unrelated modification. We basically flatten the + // function that was reached from a kernel completely by asking the inliner + // to inline everything it can. This should live in a separate AA though as + // it should also run on parallel regions and other GPU functions. + auto CheckCallInst = [&](Instruction &I) { + auto &CB = cast<CallBase>(I); + CB.addAttribute(AttributeList::FunctionIndex, Attribute::AlwaysInline); + return true; + }; + A.checkForAllCallLikeInstructions(CheckCallInst, *this); - // If we can we change the execution mode to SPMD-mode otherwise we build a - // custom state machine. 
- if (!changeToSPMDMode(A)) - buildCustomStateMachine(A); + // Fold all valid foldable functions + for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) { + if (P.second == nullptr) + continue; - return ChangeStatus::CHANGED; + A.changeValueAfterManifest(*P.first, *P.second); + A.deleteAfterManifest(*P.first); + + Change = ChangeStatus::CHANGED; + } + + Change = Change | buildCustomStateMachine(A); + + return Change; } bool changeToSPMDMode(Attributor &A) { @@ -3203,6 +3246,12 @@ if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + updateReachingKernels(A); + + // Update info regarding execution mode. + if (!MaybeReachedByUnknownKernel) + updateSPMDFolding(A); + + // Callback to check a call instruction. auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); @@ -3210,6 +3259,32 @@ *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (CBAA.getState().isValidState()) getState() ^= CBAA.getState(); + Function *Callee = CB.getCalledFunction(); + if (Callee) { + // We need to propagate information to the callee, but since the + // construction of AA always starts with kernel entries, we have to + // create AAKernelInfoFunction for all called functions. However, here + // the caller doesn't depend on the callee. + // TODO: We might want to change the dependence here later if we need + // information from callee to caller. + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Callee), this, + DepClassTy::NONE); + + auto &CBAA = A.getAAFor<AAKernelInfo>( + *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED); + if (CBAA.getState().isValidState()) { + getState() ^= CBAA.getState(); + return true; + } + } + + // The callee is not known, not ipo-amendable (e.g., due to linkage), or + // we can for some other reason not analyze it. If we cannot gather + // information, e.g., the state of the AAKernelInfo we got is invalid, we + // don't have to completely give up here. 
It basically means we have no + // idea what the effects of the call might be, for now the worst that can + // happen are unknown parallel regions hide in the callee. + UnknownParallelRegions.insert(&CB); return true; }; @@ -3219,6 +3294,116 @@ return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + +private: + /// Do some setup if the associated function is a kernel entry. + void setupForKernelEntry(Attributor &A, Function *F) { + // Add itself to the reaching kernel and set IsKernelEntry. + ReachingKernels.insert(F); + IsKernelEntry = true; + + // Check the callsite of function __kmpc_target_init to get the argument of + // SPMD mode and set IsSPMDMode accordingly. + CallBase *CB = nullptr; + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &TargetInitRFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; + + TargetInitRFI.foreachUse( + [&](Use &U, Function &Caller) { + assert(CB == nullptr && "__kmpc_target_init is called more than " + "once in one kernel entry"); + + assert(dyn_cast<CallBase>(U.getUser()) && + "Use of __kmpc_target_init is not a CallBase"); + + CB = cast<CallBase>(U.getUser()); + + return true; + }, + F); + + assert(CB && "Call to __kmpc_target_init is missing"); + + constexpr const int InitIsSPMDArgNo = 1; + + ConstantInt *IsSPMDArg = + dyn_cast<ConstantInt>(CB->getArgOperand(InitIsSPMDArgNo)); + + if (IsSPMDArg && !IsSPMDArg->isZero()) + IsSPMDMode = true; + } + + /// Update info regarding reaching kernels. 
+ void updateReachingKernels(Attributor &A) { + if (!IsKernelEntry) { + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); + if (CAA.isValidState()) { + ReachingKernels.insert(CAA.ReachingKernels.begin(), + CAA.ReachingKernels.end()); + return true; + } + + // We lost track of the caller of the associated function, any kernel + // could reach now. + MaybeReachedByUnknownKernel = true; + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + MaybeReachedByUnknownKernel = true; + } + } + + /// Update information regarding folding SPMD mode function calls. + void updateSPMDFolding(Attributor &A) { + unsigned Count = 0; + + for (Kernel K : ReachingKernels) { + auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), + DepClassTy::REQUIRED); + assert(AA.isValidState() && "AA should be valid here"); + if (AA.IsSPMDMode) + ++Count; + } + + // Assume reaching kernels are in a mixture of SPMD and non-SPMD mode. + // Update all function calls to __kmpc_is_spmd_exec_mode to nullptr. + Constant *C = nullptr; + + auto &Ctx = getAnchorValue().getContext(); + + if (Count == 0) { + // All reaching kernels are in non-SPMD mode. Update all function + // calls to __kmpc_is_spmd_exec_mode to 0. + C = ConstantInt::get(Type::getInt8Ty(Ctx), 0); + } else if (Count == ReachingKernels.size()) { + // All reaching kernels are in SPMD mode. Update all function calls to + // __kmpc_is_spmd_exec_mode to 1. 
+ C = ConstantInt::get(Type::getInt8Ty(Ctx), 1); + } + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &IsSPMDExecModeRFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; + + for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) { + CallBase *CB = P.first; + if (CB->getCalledFunction() == IsSPMDExecModeRFI.Declaration) + P.second = C; + } + } }; /// The call site kernel info abstract attribute, basically, what can we say