Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -26,6 +26,8 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/ValueTracking.h"
 
 using namespace llvm;
 using namespace omp;
@@ -185,6 +187,28 @@
     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
   };
 
+  /// Used to store/manipulate information about a runtime call that involves
+  /// host to device memory offloading.
+  struct MemoryTransfer {
+    /// Per-element accesses to, and values stored in, one offload array.
+    struct OffloadArray {
+      SmallVector<User *, 8> LastAccesses;
+      SmallVector<Value *, 8> StoredAddresses;
+    };
+
+    CallBase *RuntimeCall;
+    MemorySSA &MSSA;
+    std::unique_ptr<OffloadArray> BasePtrs;
+    std::unique_ptr<OffloadArray> Ptrs;
+    std::unique_ptr<OffloadArray> Sizes;
+
+    MemoryTransfer(CallBase *RuntimeCall, MemorySSA &MSSA)
+        : RuntimeCall(RuntimeCall), MSSA(MSSA),
+          BasePtrs(std::make_unique<OffloadArray>()),
+          Ptrs(std::make_unique<OffloadArray>()),
+          Sizes(std::make_unique<OffloadArray>()) {}
+  };
+
   /// The slice of the module we are allowed to look at.
   SmallPtrSetImpl<Function *> &ModuleSlice;
 
@@ -367,6 +391,7 @@
 
     Changed |= deduplicateRuntimeCalls();
     Changed |= deleteParallelRegions();
+    Changed |= hideMemTransfersLatency();
 
     return Changed;
   }
@@ -394,6 +419,9 @@
   }
 
 private:
+  /// Helper types.
+  using MemoryTransfer = OMPInformationCache::MemoryTransfer;
+
   /// Try to delete parallel regions if possible.
   bool deleteParallelRegions() {
     const unsigned CallbackCalleeOperand = 2;
@@ -489,6 +517,173 @@
     return Changed;
   }
 
+  /// Tries to hide the latency of runtime calls that involve host to
+  /// device memory transfers.
+  bool hideMemTransfersLatency() {
+    OMPInformationCache::RuntimeFunctionInfo &RFI =
+        OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin];
+
+    bool Changed = false;
+    auto SplitDataTransfer = [&](Use &U, Function &Decl) {
+      auto *RTCall = getCallIfRegularCall(U, &RFI);
+      if (!RTCall)
+        return false;
+
+      // Functions for which MemorySSA is unavailable are skipped.
+      auto *MSSAResult =
+          OMPInfoCache.getAnalysisResultForFunction<MemorySSAAnalysis>(
+              *RTCall->getCaller());
+      if (!MSSAResult)
+        return false;
+
+      MemoryTransfer MT(RTCall, MSSAResult->getMSSA());
+      bool WasSplit = splitMemoryTransfer(MT);
+      // Accumulate over all uses; the callback itself only reports whether
+      // this particular call was split.
+      Changed |= WasSplit;
+      return WasSplit;
+    };
+
+    RFI.foreachUse(SplitDataTransfer);
+    return Changed;
+  }
+
+  /// Intended to split the runtime call in \p MT so that the latency of the
+  /// transfer can be hidden. For now it only collects the offload arrays and
+  /// does not modify the IR, so it always returns false.
+  bool splitMemoryTransfer(MemoryTransfer &MT) {
+    if (!getValuesInOfflArrays(MT)) {
+      LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload arrays in call to "
+                        << *MT.RuntimeCall << " in function "
+                        << MT.RuntimeCall->getCaller()->getName() << "\n");
+      return false;
+    }
+
+    // TODO: Perform the actual split.
+    return false;
+  }
+
+  /// Collects the values stored in the offload_baseptrs, offload_ptrs and
+  /// offload_sizes arrays passed to the runtime call in \p MT. Stops at the
+  /// first array that cannot be resolved and returns false.
+  bool getValuesInOfflArrays(MemoryTransfer &MT) {
+    CallBase *RuntimeCall = MT.RuntimeCall;
+    const DataLayout &DL = OMPInfoCache.getDL();
+
+    // Resolves the array passed as argument \p ArgNo into \p Dst.
+    auto GetArray = [&](unsigned ArgNo, MemoryTransfer::OffloadArray &Dst,
+                        StringRef Name) {
+      Value *V = GetUnderlyingObject(RuntimeCall->getArgOperand(ArgNo), DL);
+      if (getValuesInOfflArray(V, Dst, RuntimeCall))
+        return true;
+
+      LLVM_DEBUG(dbgs() << TAG << "Couldn't get " << Name << " in call to "
+                        << *RuntimeCall << " in function "
+                        << RuntimeCall->getCaller()->getName() << "\n");
+      return false;
+    };
+
+    // Arguments 2, 3 and 4 are **offload_baseptrs, **offload_ptrs and
+    // **offload_sizes respectively.
+    return GetArray(2, *MT.BasePtrs, "offload_baseptrs") &&
+           GetArray(3, *MT.Ptrs, "offload_ptrs") &&
+           GetArray(4, *MT.Sizes, "offload_sizes");
+  }
+
+  /// Gets the values stored in \p OfflArray and stores them in \p Dst:
+  /// LastAccesses[I] is the last store to element I and StoredAddresses[I]
+  /// the underlying object of the value it stores. Only stores that precede
+  /// \p Before in its basic block are considered, in program order; without
+  /// \p Before, the whole block of the alloca is scanned. Returns false
+  /// unless a store was found for every element of the array.
+  bool getValuesInOfflArray(Value *OfflArray,
+                            MemoryTransfer::OffloadArray &Dst,
+                            User *Before = nullptr) {
+    assert(OfflArray && "Can't get values in nullptr!");
+
+    auto *ArrayAlloc = dyn_cast<AllocaInst>(OfflArray);
+    if (!ArrayAlloc || !ArrayAlloc->getAllocatedType()->isArrayTy()) {
+      LLVM_DEBUG(dbgs() << TAG << "Only alloca arrays supported.\n");
+      return false;
+    }
+
+    auto *BeforeI = dyn_cast_or_null<Instruction>(Before);
+    if (Before && !BeforeI) {
+      LLVM_DEBUG(dbgs() << TAG << "Lower bound is not an instruction.\n");
+      return false;
+    }
+
+    const DataLayout &DL = OMPInfoCache.getDL();
+    const uint64_t NumValues =
+        ArrayAlloc->getAllocatedType()->getArrayNumElements();
+    auto &LastAccesses = Dst.LastAccesses;
+    auto &StoredAddresses = Dst.StoredAddresses;
+    LastAccesses.assign(NumValues, nullptr);
+    StoredAddresses.assign(NumValues, nullptr);
+
+    // Use lists are unordered, so walk the block in program order instead.
+    // TODO: Writes through other pointers (e.g. calls the array escapes to)
+    // are not taken into account yet.
+    BasicBlock *BB = BeforeI ? BeforeI->getParent() : ArrayAlloc->getParent();
+    for (Instruction &I : *BB) {
+      if (&I == BeforeI)
+        break;
+
+      auto *Store = dyn_cast<StoreInst>(&I);
+      if (!Store)
+        continue;
+
+      // Only handle stores to `gep OfflArray, _, Idx`, possibly via a cast.
+      Value *Ptr = Store->getPointerOperand();
+      if (auto *Cast = dyn_cast<CastInst>(Ptr))
+        Ptr = Cast->getOperand(0);
+
+      auto *Access = dyn_cast<GetElementPtrInst>(Ptr);
+      if (!Access || Access->getPointerOperand() != ArrayAlloc ||
+          Access->getNumIndices() != 2)
+        continue;
+
+      const Value *ArrayIdx = Access->getOperand(2);
+      // A non-constant index could write any element; give up.
+      if (!isa<ConstantInt>(ArrayIdx)) {
+        LLVM_DEBUG(dbgs() << TAG << "Unrecognized access pattern.\n");
+        return false;
+      }
+
+      const uint64_t IdxLiteral = getIntLiteral(ArrayIdx);
+      if (IdxLiteral >= NumValues) {
+        LLVM_DEBUG(dbgs() << TAG << "Offload array index out of bounds.\n");
+        return false;
+      }
+
+      // Later stores to the same element overwrite earlier ones.
+      LastAccesses[IdxLiteral] = Store;
+      StoredAddresses[IdxLiteral] =
+          GetUnderlyingObject(Store->getValueOperand(), DL);
+    }
+
+    if (!isFilled(Dst)) {
+      LLVM_DEBUG(dbgs() << TAG << "Didn't get all values in offload array.\n");
+      return false;
+    }
+
+    return true;
+  }
+
+  /// Returns true if every element of \p OA has a last access and a stored
+  /// value.
+  bool isFilled(const MemoryTransfer::OffloadArray &OA) {
+    return !is_contained(OA.LastAccesses, nullptr) &&
+           !is_contained(OA.StoredAddresses, nullptr);
+  }
+
+  /// Returns the integer representation of \p V, which must be a
+  /// ConstantInt (cast<> asserts this in builds with assertions).
+  static uint64_t getIntLiteral(const Value *V) {
+    assert(V && "Getting Integer value of nullptr");
+    return cast<ConstantInt>(V)->getZExtValue();
+  }
+
   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                     bool GlobalOnly, bool &SingleChoice) {
     if (CurrentIdent == NextIdent)