Index: llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
===================================================================
--- llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -21,7 +21,6 @@
 #include "llvm/Analysis/MemorySSA.h"
 
 namespace llvm {
-
 namespace omp {
 using namespace types;
 
@@ -133,6 +132,30 @@
     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
   };
 
+  /// Used to store information about a runtime call that involves
+  /// host to device memory offloading.
+  struct MemoryTransfer {
+    struct OffloadArray {
+      SmallVector<User *, 8> LastAccesses;     ///< Last access per element.
+      SmallVector<Value *, 8> StoredAddresses; ///< Stored value per element.
+    };
+
+    CallBase *RuntimeCall;
+    MemorySSA &MSSA;
+    std::unique_ptr<OffloadArray> BasePtrs;
+    std::unique_ptr<OffloadArray> Ptrs;
+    std::unique_ptr<OffloadArray> Sizes;
+
+    MemoryTransfer(CallBase *RuntimeCall, MemorySSA &MSSA)
+        : RuntimeCall{RuntimeCall}, MSSA{MSSA},
+          BasePtrs{std::make_unique<OffloadArray>()},
+          Ptrs{std::make_unique<OffloadArray>()},
+          Sizes{std::make_unique<OffloadArray>()} {}
+
+    /// Returns true if no entry of \p OA is null.
+    static bool isFilled(OffloadArray &OA);
+  };
+
   /// The slice of the module we are allowed to look at.
   SmallPtrSetImpl<Function *> &ModuleSlice;
 
@@ -166,6 +189,7 @@
 
 struct OpenMPOpt {
 
+  using MemoryTransfer = OMPInformationCache::MemoryTransfer;
   using OptimizationRemarkGetter =
       function_ref<OptimizationRemarkEmitter &(Function *)>;
 
@@ -178,6 +202,10 @@
   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
   bool run();
 
+  /// Gets the values stored in the offload arrays specified by \p MT. Returns
+  /// false if some of the values couldn't be found.
+  bool getValuesInOfflArrays(MemoryTransfer &MT);
+
   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
   /// given it has to be the callee or a nullptr is returned.
   static CallInst *getCallIfRegularCall(
@@ -195,6 +223,12 @@
   /// Try to eliminiate runtime calls by reusing existing ones.
bool deduplicateRuntimeCalls();
 
+  /// Splits the runtime call in \p MT that involves a host to device transfer.
+  bool splitMemoryTransfer(MemoryTransfer &MT);
+
+  bool getValuesInOfflArray(Value *OfflArray, MemoryTransfer::OffloadArray &Dst,
+                            User *Before);
+
   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                     bool GlobalOnly, bool &SingleChoice);
 
@@ -213,6 +247,10 @@
                                OMPInformationCache::RuntimeFunctionInfo &RFI,
                                Value *ReplVal = nullptr);
 
+  /// Tries to hide the latency of runtime calls that involve host to
+  /// device memory transfers.
+  bool hideMemTransfersLatency();
+
   /// Collect arguments that represent the global thread id in \p GTIdArgs.
   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs);
 
@@ -263,6 +301,12 @@
   /// Populate the Attributor with abstract attribute opportunities in the
   /// function.
   void registerAAs();
+
+  /// Returns the integer representation of \p V, which must be a ConstantInt.
+  static uint64_t getIntLiteral(const Value *V) {
+    assert(V && "Getting Integer value of nullptr");
+    return cast<ConstantInt>(V)->getZExtValue();
+  }
 };
 
 /// Helper to remember if the module contains OpenMP (runtime calls), to be used
Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -26,6 +26,8 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/ValueTracking.h"
 
 using namespace llvm;
 using namespace omp;
@@ -223,7 +225,7 @@
   if (F->arg_size() != RTFArgTypes.size())
     return false;
 
-  auto RTFTyIt = RTFArgTypes.begin();
+  auto *RTFTyIt = RTFArgTypes.begin();
   for (Argument &Arg : F->args()) {
     if (Arg.getType() != *RTFTyIt)
       return false;
@@ -234,6 +236,17 @@
   return true;
 }
 
+bool OMPInformationCache::MemoryTransfer::isFilled(OffloadArray &OA) {
+  for (auto *Acc : OA.LastAccesses)
+    if (!Acc)
+      return false;
+
+  for (auto *Addr : OA.StoredAddresses)
+    if (!Addr)
+      return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Declarations and definitions of AAICVTracker.
 //===----------------------------------------------------------------------===//
@@ -443,6 +456,7 @@
   Changed |= runAttributor();
   Changed |= deduplicateRuntimeCalls();
   Changed |= deleteParallelRegions();
+  Changed |= hideMemTransfersLatency();
 
   return Changed;
 }
@@ -572,6 +586,164 @@
   return Changed;
 }
 
+bool OpenMPOpt::hideMemTransfersLatency() {
+  OMPInformationCache::RuntimeFunctionInfo &RFI =
+      OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin];
+
+  bool Changed = false;
+  auto SplitDataTransfer = [&](Use &U, Function &Decl) {
+    auto *RTCall = getCallIfRegularCall(U, &RFI);
+    if (!RTCall)
+      return false;
+
+    auto *MSSAResult =
+        OMPInfoCache.getAnalysisResultForFunction<MemorySSAAnalysis>(
+            *RTCall->getCaller());
+    if (!MSSAResult)
+      return false;
+
+    MemoryTransfer MT(RTCall, MSSAResult->getMSSA());
+    const bool WasSplit = splitMemoryTransfer(MT);
+    // Accumulate so that a later, unsplit call can't reset the result.
+    Changed |= WasSplit;
+    return WasSplit;
+  };
+
+  RFI.foreachUse(SplitDataTransfer);
+  return Changed;
+}
+
+bool OpenMPOpt::splitMemoryTransfer(MemoryTransfer &MT) {
+  if (!getValuesInOfflArrays(MT)) {
+    LLVM_DEBUG(dbgs() << TAG << "Couldn't get offload arrays in call to "
+                      << MT.RuntimeCall->getName() << " in function "
+                      << MT.RuntimeCall->getCaller()->getName() << "\n");
+    return false;
+  }
+
+  // TODO: Only the analysis is implemented so far; the call isn't split
+  // yet, so the IR is left untouched.
+  return false;
+}
+
+bool OpenMPOpt::getValuesInOfflArrays(MemoryTransfer &MT) {
+  // Positions of the offload arrays among the runtime call arguments.
+  const unsigned BasePtrsArgNum = 2; // **offload_baseptrs.
+  const unsigned PtrsArgNum = 3;     // **offload_ptrs.
+  const unsigned SizesArgNum = 4;    // **offload_sizes.
+
+  auto *RuntimeCall = MT.RuntimeCall;
+  const auto &DL = OMPInfoCache.getDL();
+
+  // Fills \p Dst from the array passed as argument \p ArgNum. \p ArrayName
+  // only labels the debug output.
+  auto GetValues = [&](unsigned ArgNum, MemoryTransfer::OffloadArray &Dst,
+                       const char *ArrayName) {
+    auto *V = GetUnderlyingObject(RuntimeCall->getArgOperand(ArgNum), DL);
+    if (getValuesInOfflArray(V, Dst, RuntimeCall))
+      return true;
+
+    LLVM_DEBUG(dbgs() << TAG << "Couldn't get " << ArrayName << " in call to "
+                      << RuntimeCall->getName() << " in function "
+                      << RuntimeCall->getCaller()->getName() << "\n");
+    return false;
+  };
+
+  // && short-circuits, so the first array that fails ends the analysis.
+  return GetValues(BasePtrsArgNum, *MT.BasePtrs, "offload_baseptrs") &&
+         GetValues(PtrsArgNum, *MT.Ptrs, "offload_ptrs") &&
+         GetValues(SizesArgNum, *MT.Sizes, "offload_sizes");
+}
+
+/// Gets the values stored in \p OfflArray and stores them in \p Dst.
+/// \p Before serves as a lower bound, so don't look at accesses after that.
+/// Returns false if \p OfflArray is not an alloca of array type or if a value
+/// can't be found for every element.
+bool OpenMPOpt::getValuesInOfflArray(
+    Value *OfflArray, MemoryTransfer::OffloadArray &Dst, User *Before) {
+  assert(OfflArray && "Can't get values in nullptr!");
+
+  auto *ArrayAlloc = dyn_cast<AllocaInst>(OfflArray);
+  // getArrayNumElements() below is only valid on array types.
+  if (!ArrayAlloc || !ArrayAlloc->getAllocatedType()->isArrayTy()) {
+    LLVM_DEBUG(dbgs() << TAG << "Only alloca arrays supported.\n");
+    return false;
+  }
+
+  const uint64_t NumValues =
+      ArrayAlloc->getAllocatedType()->getArrayNumElements();
+  const auto &DL = OMPInfoCache.getDL();
+
+  auto &LastAccesses = Dst.LastAccesses;
+  auto &StoredAddresses = Dst.StoredAddresses;
+  LastAccesses.assign(NumValues, nullptr);
+  StoredAddresses.assign(NumValues, nullptr);
+
+  // Get last accesses to the array right before Before.
+  // NOTE(review): users() walks the use list, which presumably isn't in
+  // program order, so neither the early exit nor "last" looks reliable --
+  // confirm and consider an ordering/dominance based walk.
+  for (auto *Usr : OfflArray->users()) {
+    // If reached lower limit.
+    if (Before && Usr == Before)
+      break;
+
+    // The array element is selected by the second index.
+    auto *Access = dyn_cast<GetElementPtrInst>(Usr);
+    if (!Access || Access->getNumIndices() < 2)
+      continue;
+
+    // Only a constant, in-range index identifies an element. Bail out on
+    // anything else instead of dereferencing null or writing out of bounds.
+    auto *ArrayIdx = dyn_cast<ConstantInt>((Access->idx_begin() + 1)->get());
+    if (!ArrayIdx || ArrayIdx->getValue().uge(NumValues)) {
+      LLVM_DEBUG(dbgs() << TAG << "Unrecognized access pattern.\n");
+      return false;
+    }
+
+    LastAccesses[getIntLiteral(ArrayIdx)] = Usr;
+  }
+
+  // Get stored addresses.
+  for (uint64_t It = 0; It < NumValues; ++It) {
+    auto *Accs = LastAccesses[It];
+    if (!Accs) {
+      LLVM_DEBUG(dbgs() << TAG << "Didn't get all values in offload array.\n");
+      return false;
+    }
+
+    if (Accs->user_empty()) {
+      LLVM_DEBUG(dbgs() << TAG << "Useless access to offload array.\n");
+      return false;
+    }
+
+    // The element address may be cast before it's stored to, so look through
+    // a single cast.
+    Value *Ptr = Accs;
+    User *PtrUsr = *Accs->user_begin();
+    if (auto *Cast = dyn_cast<CastInst>(PtrUsr)) {
+      Ptr = Cast;
+      PtrUsr = Cast->user_empty() ? nullptr : *Cast->user_begin();
+    }
+
+    // Only a store into the element is recognized.
+    auto *Store = dyn_cast_or_null<StoreInst>(PtrUsr);
+    if (!Store || Store->getPointerOperand() != Ptr) {
+      LLVM_DEBUG(dbgs() << TAG << "Unrecognized access pattern.\n");
+      return false;
+    }
+
+    StoredAddresses[It] = GetUnderlyingObject(Store->getValueOperand(), DL);
+  }
+
+  if (!MemoryTransfer::isFilled(Dst)) {
+    LLVM_DEBUG(dbgs() << TAG << "Didn't get all values in offload array.\n");
+    return false;
+  }
+
+  return true;
+}
+
 Value *OpenMPOpt::combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                       bool GlobalOnly, bool &SingleChoice) {
   if (CurrentIdent == NextIdent)