diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -67,6 +67,12 @@
           "Number of OpenMP runtime function uses identified");
 STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
+STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
+          "Number of OpenMP target region entry points (=kernels) executed in "
+          "SPMD-mode instead of generic-mode");
+STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachine,
+          "Number of OpenMP target region entry points (=kernels) executed in "
+          "generic-mode with customized state machines");
 STATISTIC(
     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
@@ -243,6 +249,11 @@
     /// Map from functions to all uses of this runtime function contained in
     /// them.
     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
+
+  public:
+    /// Iterators for the uses of this runtime function.
+    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
+    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
   };
 
   /// An OpenMP-IR-Builder instance
@@ -253,6 +264,9 @@
                   RuntimeFunction::OMPRTL___last>
       RFIs;
 
+  /// Map from function declarations/definitions to their runtime enum type.
+  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
+
   /// Map from ICV kind to the ICV description.
   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                   InternalControlVar::ICV___last>
@@ -395,6 +409,7 @@
       SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                        \
       Function *F = M.getFunction(_Name);                                     \
       if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {        \
+        RuntimeFunctionIDMap[F] = _Enum;                                      \
        auto &RFI = RFIs[_Enum];                                               \
        RFI.Kind = _Enum;                                                      \
        RFI.Name = _Name;                                                      \
@@ -423,6 +438,391 @@
   SmallPtrSetImpl<Kernel> &Kernels;
 };
 
+struct KernelInfoState : AbstractState {
+  KernelInfoState() {}
+  KernelInfoState(bool BestState) {
+    if (!BestState)
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractState::isValidState(...)
+  bool isValidState() const override { return true; }
+
+  /// See AbstractState::isAtFixpoint(...)
+  bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  ChangeStatus indicatePessimisticFixpoint() override {
+    IsAtFixpoint = true;
+    MayReachUnknownParallelRegion = true;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  ChangeStatus indicateOptimisticFixpoint() override {
+    IsAtFixpoint = true;
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// Return the assumed state.
+  KernelInfoState &getAssumed() { return *this; }
+  const KernelInfoState &getAssumed() const { return *this; }
+
+  bool operator==(const KernelInfoState &RHS) const {
+    if (MayReachUnknownParallelRegion != RHS.MayReachUnknownParallelRegion)
+      return false;
+    return ParallelRegions.size() == RHS.ParallelRegions.size();
+  }
+
+  /// Return empty set as the best state of potential values.
+  static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+  static KernelInfoState getBestState(KernelInfoState &KIS) {
+    return getBestState();
+  }
+
+  /// Return full set as the worst state of potential values.
+  static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+  /// "Clamp" this state with \p KIS.
+  KernelInfoState operator^=(const KernelInfoState &KIS) {
+    MayReachUnknownParallelRegion |= KIS.MayReachUnknownParallelRegion;
+    ParallelRegions.insert(KIS.ParallelRegions.begin(),
+                           KIS.ParallelRegions.end());
+    return *this;
+  }
+
+  KernelInfoState operator&=(const KernelInfoState &KIS) {
+    return (*this ^= KIS);
+  }
+
+public:
+  /// Flag to track if we reached a fixpoint.
+  bool IsAtFixpoint = false;
+
+  /// The parallel regions (identified by their outlined parallel functions)
+  /// known to be reachable from the associated function.
+  SmallSetVector<Function *, 4> ParallelRegions;
+
+  /// The __kmpc_target_init call in the associated kernel, if any.
+  CallBase *KernelInitCB = nullptr;
+
+  /// The __kmpc_target_deinit call in the associated kernel, if any.
+  CallBase *KernelDeinitCB = nullptr;
+
+  /// Flag to indicate the kernel might reach a parallel region we do not know.
+  bool MayReachUnknownParallelRegion = false;
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  ChangeStatus manifest(Attributor &A) override {
+    if (!KernelInitCB || !KernelDeinitCB)
+      return ChangeStatus::UNCHANGED;
+
+    // Callback to check a call instruction.
+    auto CheckCallInst = [&](Instruction &I) {
+      auto &CB = cast<CallBase>(I);
+      CB.addAttribute(AttributeList::FunctionIndex, Attribute::AlwaysInline);
+      return true;
+    };
+    if (!A.checkForAllCallLikeInstructions(CheckCallInst, *this))
+      return indicatePessimisticFixpoint();
+
+    auto &Ctx = getAnchorValue().getContext();
+    const int InitIsSPMDArgNo = 1;
+    const int InitUseStateMachineArgNo = 2;
+    ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
+        KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
+    ConstantInt *IsSPMD =
+        dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
+
+    // If we are stuck with generic mode, try to create a custom device (=GPU)
+    // state machine which is specialized for the parallel regions that are
+    // reachable by the kernel.
+    if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
+        !IsSPMD->isZero())
+      return ChangeStatus::UNCHANGED;
+
+    // Indicate we use a custom state machine now.
+    A.changeUseAfterManifest(
+        KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
+        *ConstantInt::getBool(Ctx, 0));
+
+    ++NumOpenMPTargetRegionKernelsCustomStateMachine;
+
+    // If we don't need a state machine we are done.
+    if (!MayReachUnknownParallelRegion && ParallelRegions.empty())
+      return ChangeStatus::CHANGED;
+
+    // Create all the blocks:
+    //
+    //                       InitCB = __kmpc_target_init(...)
+    //                       bool IsWorker = InitCB >= 0;
+    //                       if (IsWorker) {
+    // SMBeginBB:              __kmpc_barrier_simple_spmd(...);
+    //                         void *WorkFn;
+    //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
+    //                         if (!WorkFn) return;
+    // SMIsActiveCheckBB:      if (Active) {
+    // SMIfCascadeCurrentBB:     if (WorkFn == <ParFn0>)
+    //                             ParFn0(...);
+    // SMIfCascadeCurrentBB:     else if (WorkFn == <ParFn1>)
+    //                             ParFn1(...);
+    //                           ...
+    // SMIfCascadeCurrentBB:     else
+    //                             ((WorkFnTy*)WorkFn)(...);
+    // SMEndParallelBB:          __kmpc_kernel_end_parallel(...);
+    //                         }
+    // SMDoneBB:               __kmpc_barrier_simple_spmd(...);
+    //                         goto SMBeginBB;
+    //                       }
+    // UserCodeEntryBB:      // user code
+    //                       __kmpc_target_deinit(...)
+    //
+    Function *Kernel = getAssociatedFunction();
+    BasicBlock *InitBB = KernelInitCB->getParent();
+    BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
+        KernelInitCB->getNextNode(), "thread.user_code.check");
+    BasicBlock *StateMachineBeginBB = BasicBlock::Create(
+        Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
+    BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
+        Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
+    BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
+        Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
+    BasicBlock *StateMachineIfCascadeCurrentBB =
+        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
+                           Kernel, UserCodeEntryBB);
+    BasicBlock *StateMachineEndParallelBB =
+        BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
+                           Kernel, UserCodeEntryBB);
+    BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
+        Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
+
+    ReturnInst::Create(Ctx, StateMachineFinishedBB);
+
+    InitBB->getTerminator()->eraseFromParent();
+    Value *IsWorker =
+        ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
+                         ConstantInt::get(KernelInitCB->getType(), -1),
+                         "thread.is_worker", InitBB);
+    BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
+
+    // Create local storage for the work function pointer.
+    Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
+    AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
+                                          &Kernel->getEntryBlock().front());
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    OMPInfoCache.OMPBuilder.updateToLocation(IRBuilder<>::InsertPoint(
+        StateMachineBeginBB, StateMachineBeginBB->end()));
+
+    Value *Ident = KernelInitCB->getArgOperand(0);
+    Value *GTid = KernelInitCB;
+
+    Module &M = *Kernel->getParent();
+    FunctionCallee BarrierFn =
+        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+            M, OMPRTL___kmpc_barrier_simple_spmd);
+    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
+
+    FunctionCallee KernelParallelFn =
+        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+            M, OMPRTL___kmpc_kernel_parallel);
+    Value *IsActiveWorker = CallInst::Create(
+        KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
+    Value *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
+                                 StateMachineBeginBB);
+
+    FunctionType *ParallelRegionFnTy = FunctionType::get(
+        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
+        false);
+    Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
+        WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
+        StateMachineBeginBB);
+
+    Value *IsDone = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ,
+                                     WorkFn, Constant::getNullValue(VoidPtrTy),
+                                     "worker.is_done", StateMachineBeginBB);
+    BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
+                       IsDone, StateMachineBeginBB);
+
+    BranchInst::Create(StateMachineIfCascadeCurrentBB,
+                       StateMachineDoneBarrierBB, IsActiveWorker,
+                       StateMachineIsActiveCheckBB);
+
+    Value *ZeroArg =
+        Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
+
+    if (!ParallelRegions.empty()) {
+      for (int i = 0, e = ParallelRegions.size(); i < e; ++i) {
+        auto *ParallelRegion = ParallelRegions[i];
+        BasicBlock *PRExecuteBB = BasicBlock::Create(
+            Ctx, "worker_state_machine.parallel_region.execute", Kernel,
+            StateMachineEndParallelBB);
+        CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "",
+                         PRExecuteBB);
+        BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB);
+
+        BasicBlock *PRNextBB = BasicBlock::Create(
+            Ctx, "worker_state_machine.parallel_region.check", Kernel,
+            StateMachineEndParallelBB);
+
+        // Check if we need to compare the pointer at all or if we can just
+        // call the parallel region function.
+        Value *IsPR;
+        if (i + 1 < e || MayReachUnknownParallelRegion)
+          IsPR = ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ,
+                                  WorkFnCast, ParallelRegion,
+                                  "worker.check_parallel_region",
+                                  StateMachineIfCascadeCurrentBB);
+        else
+          IsPR = ConstantInt::getTrue(Ctx);
+
+        BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
+                           StateMachineIfCascadeCurrentBB);
+        StateMachineIfCascadeCurrentBB = PRNextBB;
+      }
+    }
+
+    if (MayReachUnknownParallelRegion) {
+      StateMachineIfCascadeCurrentBB->setName(
+          "worker_state_machine.parallel_region.fallback.execute");
+      CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
+                       StateMachineIfCascadeCurrentBB);
+    }
+    BranchInst::Create(StateMachineEndParallelBB,
+                       StateMachineIfCascadeCurrentBB);
+
+    CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
+                         M, OMPRTL___kmpc_kernel_end_parallel),
+                     {}, "", StateMachineEndParallelBB);
+    BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB);
+
+    CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB);
+    BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB);
+
+    return ChangeStatus::CHANGED;
+  }
+
+  /// Statistics are tracked as part of manifest for now.
+  void trackStatistics() const override {}
+
+  /// See AbstractAttribute::getAsStr()
+  const std::string getAsStr() const override {
+    return std::string("#PR: ") + std::to_string(ParallelRegions.size()) +
+           (MayReachUnknownParallelRegion ? " + Unknown PR" : "");
+  }
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAKernelInfo"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  /// This function should return true if the type of the \p AA is AAKernelInfo
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// Unique ID (due to the unique address)
+  static const char ID;
+};
+
+struct AAKernelInfoFunction : AAKernelInfo {
+  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {}
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    KernelInfoState StateBefore = getState();
+
+    // Callback to check a call instruction.
+    auto CheckCallInst = [&](Instruction &I) {
+      auto &CB = cast<CallBase>(I);
+      if (!CB.mayWriteToMemory())
+        return true;
+      if (isa<IntrinsicInst>(CB))
+        return true;
+      if (CB.hasFnAttr("no_openmp"))
+        return true;
+
+      Function *Callee = CB.getCalledFunction();
+      auto &OMPInfoCache =
+          static_cast<OMPInformationCache &>(A.getInfoCache());
+      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
+      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+        // Unknown callee, ask for kernel info at the call site instead.
+        auto &CBAA = A.getAAFor<AAKernelInfo>(
+            *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
+        if (CBAA.getState().isValidState())
+          getState() ^= CBAA.getState();
+        else
+          MayReachUnknownParallelRegion = true;
+        return true;
+      }
+
+      const unsigned int WrapperFunctionArgNo = 6;
+      RuntimeFunction RF = It->getSecond();
+      switch (RF) {
+      case OMPRTL___kmpc_target_init:
+        if (KernelInitCB && KernelInitCB != &CB)
+          return false;
+        KernelInitCB = &CB;
+        return true;
+      case OMPRTL___kmpc_target_deinit:
+        if (KernelDeinitCB && KernelDeinitCB != &CB)
+          return false;
+        KernelDeinitCB = &CB;
+        return true;
+      case OMPRTL___kmpc_parallel_51:
+        if (auto *ParallelRegion = dyn_cast<Function>(
+                CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
+          ParallelRegions.insert(ParallelRegion);
+          return true;
+        }
+        MayReachUnknownParallelRegion = true;
+        return true;
+      case OMPRTL___kmpc_omp_task:
+        MayReachUnknownParallelRegion = true;
+        return true;
+      default:
+        break;
+      }
+      return true;
+    };
+    if (!A.checkForAllCallLikeInstructions(CheckCallInst, *this))
+      return indicatePessimisticFixpoint();
+
+    return StateBefore == getState() ? ChangeStatus::UNCHANGED
+                                     : ChangeStatus::CHANGED;
+  }
+};
+
+struct AAKernelInfoCallSite : AAKernelInfo {
+  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    AAKernelInfo::initialize(A);
+    Function *F = getAssociatedFunction();
+    if (!F || F->isDeclaration())
+      indicatePessimisticFixpoint();
+  }
+
+  ChangeStatus updateImpl(Attributor &A) override {
+    // TODO: Once we have call site specific value information we can provide
+    //       call site specific liveness information and then it makes
+    //       sense to specialize attributes for call sites arguments instead of
+    //       redirecting requests to the callee argument.
+    Function *F = getAssociatedFunction();
+    const IRPosition &FnPos = IRPosition::function(*F);
+    auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+    if (getState() == FnAA.getState())
+      return ChangeStatus::UNCHANGED;
+    getState() = FnAA.getState();
+    return ChangeStatus::CHANGED;
+  }
+};
+
 /// Used to map the values physically (in the IR) stored in an offload
 /// array, to a vector in memory.
 struct OffloadArray {
@@ -547,14 +947,14 @@
     if (PrintOpenMPKernels)
       printKernels();
 
-    Changed |= rewriteDeviceCodeStateMachine();
-
-    Changed |= runAttributor();
+    Changed |= runAttributor(IsModulePass);
 
     // Recollect uses, in case Attributor deleted any.
     OMPInfoCache.recollectUses();
 
     Changed |= deleteParallelRegions();
+
+    Changed |= rewriteDeviceCodeStateMachine();
+
     if (HideMemoryTransferLatency)
       Changed |= hideMemTransfersLatency();
     Changed |= deduplicateRuntimeCalls();
@@ -1599,11 +1999,11 @@
   Attributor &A;
 
   /// Helper function to run Attributor on SCC.
-  bool runAttributor() {
+  bool runAttributor(bool IsModulePass) {
     if (SCC.empty())
       return false;
 
-    registerAAs();
+    registerAAs(IsModulePass);
 
     ChangeStatus Changed = A.run();
 
@@ -1615,7 +2015,7 @@
 
   /// Populate the Attributor with abstract attribute opportunities in the
   /// function.
-  void registerAAs() {
+  void registerAAs(bool IsModulePass) {
     if (SCC.empty())
       return;
 
@@ -1644,6 +2044,10 @@
       if (!F.isDeclaration())
        A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
     }
+    if (IsModulePass) {
+      for (Function *Kernel : OMPInfoCache.Kernels)
+        A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Kernel));
+    }
   }
 };
 
@@ -1780,7 +2184,7 @@
   // TODO: Checking the number of uses is not a necessary restriction and
   //       should be lifted.
   if (UnknownUse || NumDirectCalls != 1 ||
-      ToBeReplacedStateMachineUses.size() != 2) {
+      ToBeReplacedStateMachineUses.size() > 2) {
     {
       auto Remark = [&](OptimizationRemark OR) {
         return OR << "Parallel region is used in "
@@ -2411,6 +2815,7 @@
 } // namespace
 
 const char AAICVTracker::ID = 0;
+const char AAKernelInfo::ID = 0;
 const char AAExecutionDomain::ID = 0;
 
 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
@@ -2460,6 +2865,28 @@
   return *AA;
 }
 
+AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
+                                              Attributor &A) {
+  AAKernelInfo *AA = nullptr;
+  switch (IRP.getPositionKind()) {
+  case IRPosition::IRP_INVALID:
+  case IRPosition::IRP_FLOAT:
+  case IRPosition::IRP_ARGUMENT:
+  case IRPosition::IRP_RETURNED:
+  case IRPosition::IRP_CALL_SITE_RETURNED:
+  case IRPosition::IRP_CALL_SITE_ARGUMENT:
+    llvm_unreachable("KernelInfo can only be created for function position!");
+  case IRPosition::IRP_CALL_SITE:
+    AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
+    break;
+  case IRPosition::IRP_FUNCTION:
+    AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
+    break;
+  }
+
+  return *AA;
+}
+
 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
   if (!containsOpenMP(M, OMPInModule))
     return PreservedAnalyses::all();
@@ -2492,7 +2919,7 @@
   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions,
                                 OMPInModule.getKernels());
 
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor A(Functions, InfoCache, CGUpdater, nullptr, false);
 
   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
   bool Changed = OMPOpt.run(true);
@@ -2547,7 +2974,7 @@
     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
                                   Allocator,
                                   /*CGSCC*/ Functions,
                                   OMPInModule.getKernels());
 
-    Attributor A(Functions, InfoCache, CGUpdater);
+    Attributor A(Functions, InfoCache, CGUpdater, nullptr, false);
 
     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
     bool Changed = OMPOpt.run(false);
@@ -2622,7 +3049,7 @@
         *(Functions.back()->getParent()), AG, Allocator,
        /*CGSCC*/ Functions, OMPInModule.getKernels());
 
-    Attributor A(Functions, InfoCache, CGUpdater);
+    Attributor A(Functions, InfoCache, CGUpdater, nullptr, false);
 
     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
     return OMPOpt.run(false);
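
Reviewer note (illustrative sketch, not part of the patch): the kernels this
targets are generic-mode target regions like the one below. Clang outlines the
parallel body into a wrapper function and passes it to __kmpc_parallel_51; the
AAKernelInfoFunction update above records that wrapper in ParallelRegions, and
manifest then emits a worker state machine that compares the work-function
pointer against the known wrapper and calls it directly, using the indirect
call only as the MayReachUnknownParallelRegion fallback. The file name and
offload triple below are placeholders; the outlined-function names are
compiler-internal.

    // Compile sketch: clang++ -fopenmp -fopenmp-targets=<triple> example.cpp
    #include <cstdio>

    int main() {
      int A[128] = {0};
      // Generic-mode kernel: a sequential main-thread part followed by one
      // parallel region. The parallel body becomes an outlined wrapper that
      // the device runtime hands to the worker threads.
    #pragma omp target map(tofrom : A[0:128])
      {
        int Init = 42; // Executed by the main thread only.
    #pragma omp parallel for
        for (int i = 0; i < 128; ++i)
          A[i] = Init + i; // Body of the single known parallel region.
      }
      printf("%d\n", A[127]); // Expected output: 169
      return 0;
    }

With exactly one known parallel region and no unknown callees, the if-cascade
degenerates to an unconditional direct call (IsPR is constant true), which is
what makes the specialized state machine cheaper than the generic one.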