diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -448,13 +448,33 @@
   /// one we abort as the kernel is malformed.
   CallBase *KernelDeinitCB = nullptr;
 
+  /// A map from a function to its constant return value. If the value is
+  /// nullptr, the function cannot be folded.
+  SmallDenseMap<CallBase *, Constant *> FoldableFunctions;
+
   /// Flag to indicate that we may reach a parallel region that is not tracked
   /// in the ParallelRegions set above.
   SmallPtrSet<CallBase *, 2> UnknownParallelRegions;
 
+  /// Flag to indicate that we may reach a kernel entry that is not tracked.
+  bool MaybeReachedByUnknownKernel = false;
+
   /// Flag to indicate if the associated function can be executed in SPMD mode.
+  /// Note that it could be set to false if the function is already in SPMD
+  /// mode. Use \p willBeExecutedInSPMDMode to query the status instead.
   bool IsSPMDCompatible = true;
 
+  /// Flag to indicate if the associated function is in SPMD mode. This is
+  /// different from \p IsSPMDCompatible. This variable is only set to true if
+  /// it is indeed in a SPMD mode, not from being spmdized.
+  bool IsSPMDMode = false;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// Kernels that can reach the associated function.
+  SmallPtrSet<Kernel, 2> ReachingKernels;
+
   /// Abstract State interface
   ///{
@@ -475,6 +495,7 @@
     IsAtFixpoint = true;
     IsSPMDCompatible = false;
     UnknownParallelRegions.insert(nullptr);
+    MaybeReachedByUnknownKernel = true;
     return ChangeStatus::CHANGED;
   }
 
@@ -492,7 +513,17 @@
     if ((UnknownParallelRegions != RHS.UnknownParallelRegions) ||
         (IsSPMDCompatible != RHS.IsSPMDCompatible))
       return false;
-    return ParallelRegions == RHS.ParallelRegions;
+
+    if (ParallelRegions != RHS.ParallelRegions)
+      return false;
+
+    if (ReachingKernels != RHS.ReachingKernels)
+      return false;
+
+    if (FoldableFunctions != RHS.FoldableFunctions)
+      return false;
+
+    return true;
   }
 
   /// Return empty set as the best state of potential values.
@@ -520,6 +551,7 @@
     }
     UnknownParallelRegions.insert(KIS.UnknownParallelRegions.begin(),
                                   KIS.UnknownParallelRegions.end());
+    MaybeReachedByUnknownKernel |= KIS.MaybeReachedByUnknownKernel;
    IsSPMDCompatible &= KIS.IsSPMDCompatible;
     ParallelRegions.insert(KIS.ParallelRegions.begin(),
                            KIS.ParallelRegions.end());
@@ -2682,6 +2714,8 @@
   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
   /// finished now.
   ChangeStatus manifest(Attributor &A) override {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
     // This is a somewhat unrelated modification. We basically flatten the
     // function that was reached from a kernel completely by asking the inliner
     // to inline everything it can. This should live in a separate AA though as
@@ -2693,7 +2727,22 @@
     };
     A.checkForAllCallLikeInstructions(CheckCallInst, *this);
 
-    return buildCustomStateMachine(A);
+    // Fold all valid foldable functions
+    for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) {
+      if (P.second == nullptr)
+        continue;
+
+      for (User *U : P.first->users()) {
+        A.changeValueAfterManifest(*U, *P.second);
+        A.deleteAfterManifest(*P.first);
+      }
+
+      Change = ChangeStatus::CHANGED;
+    }
+
+    Change = Change | buildCustomStateMachine(A);
+
+    return Change;
   }
 
   ChangeStatus buildCustomStateMachine(Attributor &A) {
@@ -2717,7 +2766,7 @@
     auto &Ctx = getAnchorValue().getContext();
 
     // First check if we can go to SPMD-mode, that is the best option.
-    if (canBeExecutedInSPMDMode() && IsSPMD && IsSPMD->isZero()) {
+    if (willBeExecutedInSPMDMode() && IsSPMD && IsSPMD->isZero()) {
       // Indicate we use SPMD mode now.
       A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
                                *ConstantInt::getBool(Ctx, 1));
@@ -2955,8 +3004,11 @@
     return ChangeStatus::CHANGED;
   }
 
-  /// Returns true if value is assumed to be tracked.
-  bool canBeExecutedInSPMDMode() const { return IsSPMDCompatible; }
+  /// Returns true if the associated function will be executed in SPMD mode, no
+  /// matter whether it is initially in SPMD mode or spmdized.
+  bool willBeExecutedInSPMDMode() const {
+    return IsSPMDMode || IsSPMDCompatible;
+  }
 
   /// Statistics are tracked as part of manifest for now.
   void trackStatistics() const override {}
@@ -2965,9 +3017,9 @@
   const std::string getAsStr() const override {
     if (!isValidState())
       return "<invalid>";
-    return std::string(canBeExecutedInSPMDMode() ? "SPMD" : "generic") + std::string("#PRs: ") + std::to_string(ParallelRegions.size()) +
+    return std::string(willBeExecutedInSPMDMode() ? "SPMD" : "generic") +
+           std::string("#PRs: ") + std::to_string(ParallelRegions.size()) +
            ", #Unknown PRs: " + std::to_string(!UnknownParallelRegions.size());
-||||||| parent of 3665f4d7e56f (Manually rebase D102307, no tests)
   }
 
   /// Create an abstract attribute biew for the position \p IRP.
@@ -2993,6 +3045,69 @@
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
       : AAKernelInfo(IRP, A) {}
 
+  void initialize(Attributor &A) override {
+    AAKernelInfo::initialize(A);
+
+    Function *F = getAssociatedFunction();
+
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+
+    OMPInformationCache::RuntimeFunctionInfo &IsSPMDExecModeRFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
+
+    // Since we assume all kernels are spmd compatible, we assume all calls to
+    // __kmpc_is_spmd_exec_mode can be folded to 1.
+    IsSPMDExecModeRFI.foreachUse(
+        [&](Use &U, Function &Caller) {
+          CallBase *CB = dyn_cast<CallBase>(U.getUser());
+          assert(CB && "Use of __kmpc_is_spmd_exec_mode is not a CallBase");
+          FoldableFunctions[CB] =
+              ConstantInt::get(Type::getInt8Ty(Caller.getContext()), 1);
+          return true;
+        },
+        F);
+
+    if (OMPInfoCache.Kernels.count(F) == 0)
+      return;
+
+    // Starting this point, the associated function is a kernel entry.
+    // Add itself to the reaching kernel and set IsKernelEntry.
+    ReachingKernels.insert(F);
+    IsKernelEntry = true;
+
+    // Check the callsite of function __kmpc_target_init to get the argument of
+    // SPMD mode and set IsSPMDMode accordingly.
+    CallBase *CB = nullptr;
+
+    OMPInformationCache::RuntimeFunctionInfo &TargetInitRFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
+
+    TargetInitRFI.foreachUse(
+        [&](Use &U, Function &Caller) {
+          assert(CB == nullptr && "__kmpc_target_init is called more than "
+                                  "once in one kernel entry");
+
+          assert(dyn_cast<CallBase>(U.getUser()) &&
+                 "Use of __kmpc_target_init is not a CallBase");
+
+          CB = cast<CallBase>(U.getUser());
+
+          return true;
+        },
+        F);
+
+    assert(CB && "Call to __kmpc_target_init is missing");
+
+    constexpr const int InitIsSPMDArgNo = 1;
+
+    ConstantInt *IsSPMDArg =
+        dyn_cast<ConstantInt>(CB->getArgOperand(InitIsSPMDArgNo));
+
+    if (IsSPMDArg && !IsSPMDArg->isZero())
+      IsSPMDMode = true;
+  }
+
   /// Fixpoint iteration update function. Will be called every time a dependence
   /// changed its state (and in the beginning).
   ChangeStatus updateImpl(Attributor &A) override {
@@ -3021,11 +3136,26 @@
         !A.checkForAllReadWriteInstructions(CheckRWInst, *this))
       IsSPMDCompatible = false;
 
+    updateReachingKernels(A);
+
+    // Update info regarding execution mode.
+    if (!MaybeReachedByUnknownKernel)
+      updateSPMDFolding(A);
+
     // Callback to check a call instruction.
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
       Function *Callee = CB.getCalledFunction();
       if (Callee) {
+        // We need to propagate information to the callee, but since the
+        // construction of AA always starts with kernel entries, we have to
+        // create AAKernelInfoFunction for all called functions. However, here
+        // the caller doesn't depend on the callee.
+        // TODO: We might want to change the dependence here later if we need
+        // information from callee to caller.
+        A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Callee), this,
+                                         DepClassTy::NONE);
+
         auto &CBAA = A.getAAFor<AAKernelInfo>(
             *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
         if (CBAA.getState().isValidState()) {
@@ -3051,6 +3181,77 @@
     return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                      : ChangeStatus::CHANGED;
   }
+
+private:
+  /// Update info regarding reaching kernels.
+  void updateReachingKernels(Attributor &A) {
+    if (!IsKernelEntry) {
+      auto PredCallSite = [&](AbstractCallSite ACS) {
+        Function *Caller = ACS.getInstruction()->getFunction();
+
+        assert(Caller && "Caller is nullptr");
+
+        auto &CAA =
+            A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+        if (CAA.isValidState()) {
+          ReachingKernels.insert(CAA.ReachingKernels.begin(),
+                                 CAA.ReachingKernels.end());
+          return true;
+        }
+
+        // We lost track of the caller of the associated function, any kernel
+        // could reach now.
+        MaybeReachedByUnknownKernel = true;
+
+        return true;
+      };
+
+      bool AllCallSitesKnown;
+      if (!A.checkForAllCallSites(PredCallSite, *this,
+                                  true /* RequireAllCallSites */,
+                                  AllCallSitesKnown))
+        MaybeReachedByUnknownKernel = true;
+    }
+  }
+
+  /// Update information regarding folding SPMD mode function calls.
+  void updateSPMDFolding(Attributor &A) {
+    unsigned Count = 0;
+
+    for (Kernel K : ReachingKernels) {
+      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+                                          DepClassTy::REQUIRED);
+      assert(AA.isValidState() && "AA should be valid here");
+      if (AA.willBeExecutedInSPMDMode())
+        ++Count;
+    }
+
+    // Assume reaching kernels are in a mixture of SPMD and non-SPMD mode.
+    // Update all function calls to __kmpc_is_spmd_exec_mode to nullptr.
+    Constant *C = nullptr;
+
+    auto &Ctx = getAnchorValue().getContext();
+
+    if (Count == 0) {
+      // All reaching kernels are in non-SPMD mode. Update all function
+      // calls to __kmpc_is_spmd_exec_mode to 0.
+      C = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
+    } else if (Count == ReachingKernels.size()) {
+      // All reaching kernels are in SPMD mode. Update all function calls to
+      // __kmpc_is_spmd_exec_mode to 1.
+      C = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+    }
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    OMPInformationCache::RuntimeFunctionInfo &IsSPMDExecModeRFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
+
+    for (std::pair<CallBase *, Constant *> &P : FoldableFunctions) {
+      CallBase *CB = P.first;
+      if (CB->getCalledFunction() == IsSPMDExecModeRFI.Declaration)
+        P.second = C;
+    }
+  }
 };
 
 /// The call site kernel info abstract attribute, basically, what can we say