Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp =================================================================== --- llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -26,6 +26,8 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/Attributor.h" #include "llvm/Transforms/Utils/CallGraphUpdater.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/ValueTracking.h" using namespace llvm; using namespace omp; @@ -185,6 +187,28 @@ DenseMap> UsesMap; }; + /// Used to store/manipulate information about a runtime call that involves + /// host to device memory offloading. + struct MemoryTransfer { + struct OffloadArray { + SmallVector LastAccesses; + SmallVector StoredAddresses; + }; + + CallBase *RuntimeCall; + MemorySSA &MSSA; + std::unique_ptr BasePtrs; + std::unique_ptr Ptrs; + std::unique_ptr Sizes; + + MemoryTransfer(CallBase *RuntimeCall, MemorySSA &MSSA) : + RuntimeCall{RuntimeCall}, MSSA{MSSA}, + BasePtrs {std::make_unique()}, + Ptrs {std::make_unique()}, + Sizes {std::make_unique()} + {} + }; + /// The slice of the module we are allowed to look at. SmallPtrSetImpl &ModuleSlice; @@ -329,12 +353,14 @@ using OptimizationRemarkGetter = function_ref; + using MemorySSAGetter = function_ref; OpenMPOpt(SmallVectorImpl &SCC, CallGraphUpdater &CGUpdater, - OptimizationRemarkGetter OREGetter, + OptimizationRemarkGetter OREGetter, MemorySSAGetter MSSAGetter, OMPInformationCache &OMPInfoCache) : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater), - OREGetter(OREGetter), OMPInfoCache(OMPInfoCache) {} + OREGetter(OREGetter), MSSAGetter(MSSAGetter), + OMPInfoCache(OMPInfoCache) {} /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice. bool run() { @@ -367,6 +393,7 @@ Changed |= deduplicateRuntimeCalls(); Changed |= deleteParallelRegions(); + Changed |= hideMemTransfersLatency(); return Changed; } @@ -394,6 +421,9 @@ } private: + /// Helper types. 
+ using MemoryTransfer = OMPInformationCache::MemoryTransfer; + /// Try to delete parallel regions if possible. bool deleteParallelRegions() { const unsigned CallbackCalleeOperand = 2; @@ -489,6 +519,179 @@ return Changed; } + /// Tries to hide the latency of runtime calls that involve host to + /// device memory transfers. + bool hideMemTransfersLatency() { + OMPInformationCache::RuntimeFunctionInfo &RFI = + OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin]; + + bool Changed = false; + auto SplitDataTransfer = [&] (Use &U, Function &Decl) { + // Report only this use's result; the overall result is accumulated. + bool Split = false; + if (CallInst *RT = getCallIfRegularCall(U, &RFI)) { + if (auto *MSSA = MSSAGetter(RT->getCaller())) { + MemoryTransfer MT(RT, *MSSA); + Split = splitMemoryTransfer(MT); + } + } + Changed |= Split; + return Split; + }; + + RFI.foreachUse(SplitDataTransfer); + return Changed; + } + + /// Gathers the offload arrays of \p MT. No transformation is applied yet, so + /// this always returns false. + bool splitMemoryTransfer(MemoryTransfer &MT) { + bool Changed = false; + unsigned Status = EXIT_SUCCESS; + + Status = getValuesInOfflArrays(MT); + if (Status == EXIT_FAILURE) { + LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload arrays in call to " + << MT.RuntimeCall->getName() << " in function " + << MT.RuntimeCall->getCaller()->getName() << "\n"); + return Changed; + } + + return Changed; + } + + /// Fills the offload arrays of \p MT from arguments 2, 3 and 4 of its runtime + /// call. Returns EXIT_SUCCESS, or EXIT_FAILURE if any of them cannot be read. + unsigned getValuesInOfflArrays(MemoryTransfer &MT) { + auto *RuntimeCall = MT.RuntimeCall; + auto *BasePtrsArg = RuntimeCall->arg_begin() + 2; // **offload_baseptrs. + auto *PtrsArg = RuntimeCall->arg_begin() + 3; // **offload_ptrs. + auto *SizesArg = RuntimeCall->arg_begin() + 4; // **offload_sizes. + const auto &DL = OMPInfoCache.getDL(); + + // Get values stored in **offload_baseptrs. 
+ auto *V = GetUnderlyingObject(BasePtrsArg->get(), DL); + unsigned Status = getValuesInOfflArray(V, *MT.BasePtrs, RuntimeCall); + if (Status == EXIT_FAILURE) { + LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload_baseptrs in call to " + << RuntimeCall->getName() << " in function " + << RuntimeCall->getCaller()->getName() << "\n"); + return EXIT_FAILURE; + } + + // Get values stored in **offload_ptrs. + V = GetUnderlyingObject(PtrsArg->get(), DL); + Status = getValuesInOfflArray(V, *MT.Ptrs, RuntimeCall); + if (Status == EXIT_FAILURE) { + LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload_ptrs in call to " + << RuntimeCall->getName() << " in function " + << RuntimeCall->getCaller()->getName() << "\n"); + return EXIT_FAILURE; + } + + // Get values stored in **offload_sizes. + V = GetUnderlyingObject(SizesArg->get(), DL); + Status = getValuesInOfflArray(V, *MT.Sizes, RuntimeCall); + if (Status == EXIT_FAILURE) { + LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload_sizes in call to " + << RuntimeCall->getName() << " in function " + << RuntimeCall->getCaller()->getName() << "\n"); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; + } + + /// Gets the values stored in \p OfflArray and stores them in \p Dst. + /// Scanning of the users of \p OfflArray stops once \p Before is reached. + unsigned getValuesInOfflArray(Value *OfflArray, + MemoryTransfer::OffloadArray &Dst, + User *Before = nullptr) { + assert(OfflArray && "Can't get values in nullptr!"); + + if (!isa(OfflArray)) { + LLVM_DEBUG(dbgs() << TAG << "Only alloca arrays supported.\n"); + return EXIT_FAILURE; + } + + auto *ArrayAlloc = cast(OfflArray); + const uint64_t NumValues = + ArrayAlloc->getAllocatedType()->getArrayNumElements(); + + auto &LastAccesses = Dst.LastAccesses; + auto &StoredAddresses = Dst.StoredAddresses; + LastAccesses.assign(NumValues, nullptr); + StoredAddresses.assign(NumValues, nullptr); + + // Get last accesses to the array right before Before. 
+ for (auto *Usr : OfflArray->users()) { + // Stop at Before. NOTE(review): users() may not be in program order. + if (Before && Usr == Before) + break; + + auto *Access = dyn_cast(Usr); + if (!Access) + continue; + + auto *ArrayIdx = Access->idx_begin() + 1; + if (ArrayIdx == Access->idx_end()) + continue; + + auto *IdxV = ArrayIdx->get(); + if (!isa<ConstantInt>(IdxV) || getIntLiteral(IdxV) >= NumValues) { + LLVM_DEBUG(dbgs() << TAG << "Unsupported offload array index.\n"); + return EXIT_FAILURE; + } + LastAccesses[getIntLiteral(IdxV)] = Usr; + } + + // Get stored addresses. + for (unsigned It = 0; It < NumValues; ++It) { + auto *Accs = LastAccesses[It]; + if (!Accs || Accs->user_empty()) { + LLVM_DEBUG(dbgs() << TAG << "Useless access to offload array.\n"); + return EXIT_FAILURE; + } + auto AccsUsr = Accs->user_begin(); + + auto *I = cast(*AccsUsr); + if (I->isCast() && !I->user_empty()) + AccsUsr = I->user_begin(); + + if (!isa(*AccsUsr)) { + LLVM_DEBUG(dbgs() << TAG << "Unrecognized access pattern.\n"); + return EXIT_FAILURE; + } + + StoredAddresses[It] = AccsUsr->getOperand(0); + } + + if (!isFilled(Dst)) { + LLVM_DEBUG(dbgs() << TAG << "Couldn't fill offload array.\n"); + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; + } + + /// Returns true if no entry of \p OA is null. + bool isFilled(const MemoryTransfer::OffloadArray &OA) { + for (auto *Acc : OA.LastAccesses) + if (!Acc) + return false; + + for (auto *Addr : OA.StoredAddresses) + if (!Addr) + return false; + return true; + } + + /// Returns the zero-extended value of \p V, which must be a ConstantInt. + static uint64_t getIntLiteral(const Value *V) { + assert(V && "Getting Integer value of nullptr"); + return cast<ConstantInt>(V)->getZExtValue(); + } + static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, bool GlobalOnly, bool &SingleChoice) { if (CurrentIdent == NextIdent) @@ -724,6 +927,9 @@ /// Callback to get an OptimizationRemarkEmitter from a Function * OptimizationRemarkGetter OREGetter; + /// Used to get the MemorySSA analysis of a specified function. + MemorySSAGetter MSSAGetter; + /// OpenMP-specific information cache. Also Used for Attributor runs. 
OMPInformationCache &OMPInfoCache; }; @@ -757,6 +963,12 @@ return FAM.getResult(*F); }; + auto MSSAGetter = [&C, &CG, &AM](Function *F) -> MemorySSA * { + FunctionAnalysisManager &FAM = + AM.getResult(C, CG).getManager(); + return &(FAM.getResult(*F).getMSSA()); + }; + CallGraphUpdater CGUpdater; CGUpdater.initialize(CG, C, AM, UR); @@ -766,7 +978,7 @@ /*CGSCC*/ &Functions, ModuleSlice); // TODO: Compute the module slice we are allowed to look at. - OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache); + OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, MSSAGetter, InfoCache); bool Changed = OMPOpt.run(); (void)Changed; return PreservedAnalyses::all(); @@ -823,6 +1035,11 @@ return *ORE; }; + // No MemorySSA is provided on this path; the getter's result is null-checked. + auto MSSAGetter = [](Function *F) -> MemorySSA * { + return nullptr; + }; + AnalysisGetter AG; SetVector Functions(SCC.begin(), SCC.end()); BumpPtrAllocator Allocator; @@ -831,7 +1048,7 @@ /*CGSCC*/ &Functions, ModuleSlice); // TODO: Compute the module slice we are allowed to look at. - OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache); + OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, MSSAGetter, InfoCache); return OMPOpt.run(); }