diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -497,6 +497,12 @@ /// one we abort as the kernel is malformed. CallBase *KernelDeinitCB = nullptr; + /// Flag to indicate if the associated function is a kernel entry. + bool IsKernelEntry = false; + + /// State to track what kernel entries can reach the associated function. + BooleanStateWithPtrSetVector ReachingKernelEntries; + /// Abstract State interface ///{ @@ -537,6 +543,8 @@ return false; if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions) return false; + if (ReachingKernelEntries != RHS.ReachingKernelEntries) + return false; return true; } @@ -2725,6 +2733,10 @@ if (!OMPInfoCache.Kernels.count(Fn)) return; + // Add itself to the reaching kernel and set IsKernelEntry. + ReachingKernelEntries.insert(Fn); + IsKernelEntry = true; + OMPInformationCache::RuntimeFunctionInfo &InitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; OMPInformationCache::RuntimeFunctionInfo &DeinitRFI = @@ -3206,6 +3218,9 @@ if (!A.checkForAllReadWriteInstructions(CheckRWInst, *this)) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + if (!IsKernelEntry) + updateReachingKernelEntries(A); + // Callback to check a call instruction. auto CheckCallInst = [&](Instruction &I) { auto &CB = cast(I); @@ -3213,6 +3228,19 @@ *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL); if (CBAA.getState().isValidState()) getState() ^= CBAA.getState(); + + Function *Callee = CB.getCalledFunction(); + if (Callee) { + // We need to propagate information to the callee, but since the + // construction of AA always starts with kernel entries, we have to + // create AAKernelInfoFunction for all called functions. However, here + // the caller doesn't depend on the callee. + // TODO: We might want to change the dependence here later if we need + // information from callee to caller. 
+        A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Callee), this,
+                                         DepClassTy::NONE);
+      }
+
       return true;
     };
@@ -3222,6 +3250,35 @@
     return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                      : ChangeStatus::CHANGED;
   }
+
+private:
+  /// Update info regarding reaching kernels.
+  void updateReachingKernelEntries(Attributor &A) {
+    auto PredCallSite = [&](AbstractCallSite ACS) {
+      Function *Caller = ACS.getInstruction()->getFunction();
+
+      assert(Caller && "Caller is nullptr");
+
+      auto &CAA =
+          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+      if (CAA.isValidState()) {
+        ReachingKernelEntries ^= CAA.ReachingKernelEntries;
+        return true;
+      }
+
+      // We lost track of the caller of the associated function, any kernel
+      // could reach now.
+      ReachingKernelEntries.indicatePessimisticFixpoint();
+
+      return true;
+    };
+
+    bool AllCallSitesKnown;
+    if (!A.checkForAllCallSites(PredCallSite, *this,
+                                true /* RequireAllCallSites */,
+                                AllCallSitesKnown))
+      ReachingKernelEntries.indicatePessimisticFixpoint();
+  }
};

/// The call site kernel info abstract attribute, basically, what can we say
@@ -3290,6 +3347,7 @@
     switch (RF) {
     // All the functions we know are compatible with SPMD mode.
     case OMPRTL___kmpc_is_spmd_exec_mode:
+      return;
     case OMPRTL___kmpc_for_static_fini:
     case OMPRTL___kmpc_global_thread_num:
     case OMPRTL___kmpc_single:
@@ -3368,6 +3426,171 @@
   }
};

+struct AAFoldRuntimeCall
+    : public StateWrapper<BooleanState, AbstractAttribute> {
+  using Base = StateWrapper<BooleanState, AbstractAttribute>;
+
+  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Statistics are tracked as part of manifest for now.
+  void trackStatistics() const override {}
+
+  /// See AbstractAttribute::getAsStr()
+  const std::string getAsStr() const override {
+    if (!isValidState())
+      return "<invalid>";
+    return std::string("#RKs: ") +
+           std::to_string(ReachingKernelEntries.size());
+  }
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
+                                              Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAFoldRuntimeCall"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is
+  /// AAFoldRuntimeCall
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  static const char ID;
+};
+
+struct AAFoldRuntimeCallCallSite : AAFoldRuntimeCall {
+  AAFoldRuntimeCallCallSite(const IRPosition &IRP, Attributor &A)
+      : AAFoldRuntimeCall(IRP, A) {}
+
+  void initialize(Attributor &A) override {
+    Function *Callee = getAssociatedFunction();
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
+           "Expected a known OpenMP runtime function");
+
+    RFKind = It->getSecond();
+
+    CallBase &CB = cast<CallBase>(getAssociatedValue());
+    A.registerSimplificationCallback(
+        IRPosition::callsite_function(CB),
+        [&](const IRPosition &IRP, const AbstractAttribute *AA,
+            bool &UsedAssumedInformation) -> Optional<Value *> {
+          if (!isAtFixpoint()) {
+            UsedAssumedInformation = true;
+            if (AA)
+              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
+          }
+          return SimplifiedValue;
+        });
+  }
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+    switch (RFKind) {
+    case OMPRTL___kmpc_is_spmd_exec_mode:
+      Changed = Changed | foldIsSPMDExecMode(A);
+      break;
+    default:
+      llvm_unreachable("Unhandled OpenMP runtime function!");
+    }
+
+    return Changed;
+  }
+
+  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
+  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
+    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
+    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
+    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
+
+    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
+      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+                                          DepClassTy::REQUIRED);
+
+      if (!AA.isValidState()) {
+        SimplifiedValue = nullptr;
+        indicatePessimisticFixpoint();
+
+        return ChangeStatus::CHANGED;
+      }
+
+      if (AA.SPMDCompatibilityTracker.isAssumed()) {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownSPMDCount;
+        else
+          ++AssumedSPMDCount;
+      } else {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownNonSPMDCount;
+        else
+          ++AssumedNonSPMDCount;
+      }
+    }
+
+    if (KnownSPMDCount && KnownNonSPMDCount) {
+      SimplifiedValue = nullptr;
+      indicatePessimisticFixpoint();
+
+      return ChangeStatus::CHANGED;
+    }
+    if (AssumedSPMDCount || AssumedNonSPMDCount) {
+      assert(
+          !isAtFixpoint() && !SimplifiedValue.hasValue() &&
+          "Simplified value should be null while we use assumed information!");
+      return ChangeStatus::UNCHANGED;
+    }
+
+    auto &Ctx = getAnchorValue().getContext();
+    if (KnownSPMDCount) {
+      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
+             AssumedSPMDCount == 0 && "Expected only SPMD kernels!");
+      // All reaching kernels are in SPMD mode. Update all function calls to
+      // __kmpc_is_spmd_exec_mode to 1.
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
+    } else {
+      assert(KnownSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
+             AssumedSPMDCount == 0 && "Expected only non-SPMD kernels!");
+      // All reaching kernels are in non-SPMD mode. Update all function
+      // calls to __kmpc_is_spmd_exec_mode to 0.
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
+    }
+
+    indicateOptimisticFixpoint();
+
+    return ChangeStatus::CHANGED;
+  }
+
+  ChangeStatus manifest(Attributor &A) override {
+    ChangeStatus Changed = ChangeStatus::UNCHANGED;
+
+    if (!SimplifiedValue.hasValue()) {
+      A.deleteAfterManifest(*getCtxI());
+    } else if (*SimplifiedValue != nullptr) {
+      Instruction &CB = *getCtxI();
+      A.changeValueAfterManifest(CB, **SimplifiedValue);
+      A.deleteAfterManifest(CB);
+    }
+
+    return Changed;
+  }
+
+private:
+  /// An optional value the associated value is assumed to fold to. That is, we
+  /// assume the associated value (which is a call) can be replaced by this
+  /// simplified value.
+  Optional<Value *> SimplifiedValue;
+
+  /// The runtime function kind of the callee of the associated call site.
+  RuntimeFunction RFKind;
+};
+
 } // namespace

void OpenMPOpt::registerAAs(bool IsModulePass) {
@@ -3384,6 +3607,18 @@
         IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
         DepClassTy::NONE, /* ForceUpdate */ false,
         /* UpdateAfterInit */ false);
+
+    auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
+    IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) {
+      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI);
+      if (!CI)
+        return false;
+      A.getOrCreateAAFor<AAFoldRuntimeCall>(
+          IRPosition::callsite_function(*CI), /* QueryingAA */ nullptr,
+          DepClassTy::NONE, /* ForceUpdate */ false,
+          /* UpdateAfterInit */ false);
+      return false;
+    });
   }

  // Create CallSite AA for all Getters.
@@ -3427,6 +3662,7 @@ const char AAKernelInfo::ID = 0; const char AAExecutionDomain::ID = 0; const char AAHeapToShared::ID = 0; +const char AAFoldRuntimeCall::ID = 0; AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP, Attributor &A) { @@ -3518,6 +3754,26 @@ return *AA; } +AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP, + Attributor &A) { + AAFoldRuntimeCall *AA = nullptr; + switch (IRP.getPositionKind()) { + case IRPosition::IRP_INVALID: + case IRPosition::IRP_FLOAT: + case IRPosition::IRP_ARGUMENT: + case IRPosition::IRP_RETURNED: + case IRPosition::IRP_CALL_SITE_RETURNED: + case IRPosition::IRP_CALL_SITE_ARGUMENT: + case IRPosition::IRP_FUNCTION: + llvm_unreachable("KernelInfo can only be created for call site position!"); + case IRPosition::IRP_CALL_SITE: + AA = new (A.Allocator) AAFoldRuntimeCallCallSite(IRP, A); + break; + } + + return *AA; +} + PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) { if (!containsOpenMP(M)) return PreservedAnalyses::all();