diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -497,6 +497,21 @@ /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; + /// A map from a function to its constant return value. If the value is + /// nullptr, the function cannot be folded. + SmallDenseMap<CallBase *, Constant *> FoldableFunctions; + + /// Flag to indicate if the associated function is in SPMD mode. This variable + /// is only set to true if it is indeed in a SPMD mode, not from being + /// spmdized. + bool IsSPMDMode = false; + + /// Flag to indicate if the associated function is a kernel entry. + bool IsKernelEntry = false; + + /// State to track what kernel entries can reach the associated function. + BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries; + + /// Abstract State interface ///{ @@ -517,6 +532,7 @@ IsAtFixpoint = true; SPMDCompatibilityTracker.indicatePessimisticFixpoint(); ReachedUnknownParallelRegions.indicatePessimisticFixpoint(); + ReachingKernelEntries.indicatePessimisticFixpoint(); return ChangeStatus::CHANGED; } @@ -535,11 +551,15 @@ return false; if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions) return false; - if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions) + + if (ReachingKernelEntries != RHS.ReachingKernelEntries) return false; + + if (FoldableFunctions != RHS.FoldableFunctions) + return false; + return true; } - /// Return empty set as the best state of potential values. 
static KernelInfoState getBestState() { return KernelInfoState(true); } @@ -566,6 +586,7 @@ SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker; ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions; ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions; + ReachingKernelEntries ^= KIS.ReachingKernelEntries; return *this; } @@ -2826,21 +2847,33 @@ /// Modify the IR based on the KernelInfoState as the fixpoint iteration is /// finished now. ChangeStatus manifest(Attributor &A) override { - // If we are not looking at a kernel with __kmpc_target_init and - // __kmpc_target_deinit call we cannot actually manifest the information. - if (!KernelInitCB || !KernelDeinitCB) - return ChangeStatus::UNCHANGED; + ChangeStatus Change = ChangeStatus::UNCHANGED; - // Known SPMD-mode kernels need no manifest changes. - if (SPMDCompatibilityTracker.isKnown()) - return ChangeStatus::UNCHANGED; + // This is a somewhat unrelated modification. We basically flatten the + // function that was reached from a kernel completely by asking the inliner + // to inline everything it can. This should live in a separate AA though as + // it should also run on parallel regions and other GPU functions. + auto CheckCallInst = [&](Instruction &I) { + auto &CB = cast<CallBase>(I); + CB.addAttribute(AttributeList::FunctionIndex, Attribute::AlwaysInline); + return true; + }; + A.checkForAllCallLikeInstructions(CheckCallInst, *this); + + // Fold all valid foldable functions + for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) { + if (P.second == nullptr) + continue; - // If we can we change the execution mode to SPMD-mode otherwise we build a - // custom state machine. 
- if (!changeToSPMDMode(A)) - buildCustomStateMachine(A); + A.changeValueAfterManifest(*P.first, *P.second); + A.deleteAfterManifest(*P.first); - return ChangeStatus::CHANGED; + Change = ChangeStatus::CHANGED; + } + + Change = Change | buildCustomStateMachine(A); + + return Change; } bool changeToSPMDMode(Attributor &A) { @@ -3203,6 +3236,12 @@ if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + updateReachingKernelEntries(A); + + // Update info regarding execution mode. + if (!ReachingKernelEntries.isAssumed()) + updateSPMDFolding(A); + + // Callback to check a call instruction. auto CheckCallInst = [&](Instruction &I) { auto &CB = cast<CallBase>(I); @@ -3210,6 +3249,19 @@ *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (CBAA.getState().isValidState()) getState() ^= CBAA.getState(); + + Function *Callee = CB.getCalledFunction(); + if (Callee) { + // We need to propagate information to the callee, but since the + // construction of AA always starts with kernel entries, we have to + // create AAKernelInfoFunction for all called functions. However, here + // the caller doesn't depend on the callee. + // TODO: We might want to change the dependence here later if we need + // information from callee to caller. + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Callee), this, + DepClassTy::NONE); + } + return true; }; @@ -3219,6 +3271,115 @@ return StateBefore == getState() ? ChangeStatus::UNCHANGED : ChangeStatus::CHANGED; } + +private: + /// Do some setup if the associated function is a kernel entry. + void setupForKernelEntry(Attributor &A, Function *F) { + // Add itself to the reaching kernel and set IsKernelEntry. + ReachingKernelEntries.insert(F); + IsKernelEntry = true; + + // Check the callsite of function __kmpc_target_init to get the argument of + // SPMD mode and set IsSPMDMode accordingly. 
+ CallBase *CB = nullptr; + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &TargetInitRFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; + + TargetInitRFI.foreachUse( + [&](Use &U, Function &Caller) { + assert(CB == nullptr && "__kmpc_target_init is called more than " + "once in one kernel entry"); + + assert(dyn_cast<CallBase>(U.getUser()) && + "Use of __kmpc_target_init is not a CallBase"); + + CB = cast<CallBase>(U.getUser()); + + return true; + }, + F); + + assert(CB && "Call to __kmpc_target_init is missing"); + + constexpr const int InitIsSPMDArgNo = 1; + + ConstantInt *IsSPMDArg = + dyn_cast<ConstantInt>(CB->getArgOperand(InitIsSPMDArgNo)); + + if (IsSPMDArg && !IsSPMDArg->isZero()) + IsSPMDMode = true; + } + + /// Update info regarding reaching kernels. + void updateReachingKernelEntries(Attributor &A) { + if (!IsKernelEntry) { + auto PredCallSite = [&](AbstractCallSite ACS) { + Function *Caller = ACS.getInstruction()->getFunction(); + + assert(Caller && "Caller is nullptr"); + + auto &CAA = + A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller)); + if (CAA.isValidState()) { + ReachingKernelEntries ^= CAA.ReachingKernelEntries; + return true; + } + + // We lost track of the caller of the associated function, any kernel + // could reach now. + ReachingKernelEntries.indicatePessimisticFixpoint(); + + return true; + }; + + bool AllCallSitesKnown; + if (!A.checkForAllCallSites(PredCallSite, *this, + true /* RequireAllCallSites */, + AllCallSitesKnown)) + ReachingKernelEntries.indicatePessimisticFixpoint(); + } + } + + /// Update information regarding folding SPMD mode function calls. 
+ void updateSPMDFolding(Attributor &A) { + unsigned Count = 0; + + for (Kernel K : ReachingKernelEntries) { + auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K), + DepClassTy::REQUIRED); + assert(AA.isValidState() && "AA should be valid here"); + if (AA.SPMDCompatibilityTracker.isAssumed()) + ++Count; + } + + // Assume reaching kernels are in a mixture of SPMD and non-SPMD mode. + // Update all function calls to __kmpc_is_spmd_exec_mode to nullptr. + Constant *C = nullptr; + + auto &Ctx = getAnchorValue().getContext(); + + if (Count == 0) { + // All reaching kernels are in non-SPMD mode. Update all function + // calls to __kmpc_is_spmd_exec_mode to 0. + C = ConstantInt::get(Type::getInt8Ty(Ctx), 0); + } else if (Count == ReachingKernelEntries.size()) { + // All reaching kernels are in SPMD mode. Update all function calls to + // __kmpc_is_spmd_exec_mode to 1. + C = ConstantInt::get(Type::getInt8Ty(Ctx), 1); + } + + auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache()); + OMPInformationCache::RuntimeFunctionInfo &IsSPMDExecModeRFI = + OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; + + for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) { + CallBase *CB = P.first; + if (CB->getCalledFunction() == IsSPMDExecModeRFI.Declaration) + P.second = C; + } + } }; /// The call site kernel info abstract attribute, basically, what can we say