Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def =================================================================== --- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -286,6 +286,7 @@ OMP_STRUCT_TYPE(VarName, "struct." #Name, __VA_ARGS__) __OMP_STRUCT_TYPE(Ident, ident_t, Int32, Int32, Int32, Int32, Int8Ptr) +__OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr) #undef __OMP_STRUCT_TYPE #undef OMP_STRUCT_TYPE @@ -570,6 +571,9 @@ VoidPtrPtr, Int64Ptr, Int64Ptr) __OMP_RTL(__tgt_target_data_begin_nowait, false, Void, Int64, Int32, VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr) +__OMP_RTL(__tgt_target_data_begin_issue, false, AsyncInfo, Int64, Int32, VoidPtrPtr, + VoidPtrPtr, Int64Ptr, Int64Ptr) +__OMP_RTL(__tgt_target_data_begin_wait, false, Void, Int64, AsyncInfo) __OMP_RTL(__tgt_target_data_end, false, Void, Int64, Int32, VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr) __OMP_RTL(__tgt_target_data_end_nowait, false, Void, Int64, Int32, VoidPtrPtr, Index: llvm/include/llvm/Transforms/IPO/OpenMPOpt.h =================================================================== --- llvm/include/llvm/Transforms/IPO/OpenMPOpt.h +++ llvm/include/llvm/Transforms/IPO/OpenMPOpt.h @@ -180,8 +180,11 @@ /// These help mapping the values in offload_baseptrs, offload_ptrs, and /// offload_sizes, respectively. + const unsigned BasePtrsArgNum = 2; std::unique_ptr BasePtrs = nullptr; + const unsigned PtrsArgNum = 3; std::unique_ptr Ptrs = nullptr; + const unsigned SizesArgNum = 4; std::unique_ptr Sizes = nullptr; /// Set of instructions that compose the argument setup for the call @@ -207,6 +210,11 @@ /// offload arrays. bool mayBeModifiedBy(Instruction *I); + /// Splits this object into its "issue" and "wait" corresponding runtime + /// calls. The "issue" is moved after \p After and the "wait" is moved + /// before \p Before. 
+ bool split(Instruction *After, Instruction *Before); + private: /// Gets the setup instructions for each of the values in \p OA. These /// instructions are stored into Issue. Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp =================================================================== --- llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -44,6 +44,11 @@ static cl::opt PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden); +static cl::opt SplitMemoryTransfers( + "openmp-split-memtransfers", + cl::desc("Tries to hide the latency of host to device memory transfers"), + cl::Hidden, cl::init(false)); + STATISTIC(NumOpenMPRuntimeCallsDeduplicated, "Number of OpenMP runtime calls deduplicated"); STATISTIC(NumOpenMPParallelRegionsDeleted, @@ -337,6 +342,9 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *BasePtrsGEP = + cast(RuntimeCall->getArgOperand(BasePtrsArgNum)); + Issue.insert(BasePtrsGEP); Success = getSetupInstructions(Ptrs); if (!Success) { @@ -346,6 +354,9 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *PtrsGEP = + cast(RuntimeCall->getArgOperand(PtrsArgNum)); + Issue.insert(PtrsGEP); if (Sizes) { Success = getSetupInstructions(Sizes); @@ -356,6 +367,9 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *SizesGEP = + cast(RuntimeCall->getArgOperand(SizesArgNum)); + Issue.insert(SizesGEP); } return true; @@ -495,6 +509,49 @@ return true; } +bool MemoryTransfer::split(Instruction *After, Instruction *Before) { + assert((After || Before) && + "Must have a place to move the split runtime call"); + + auto *M = RuntimeCall->getModule(); + auto &IRBuilder = InfoCache.OMPBuilder; + // Add "issue" runtime call declaration. 
+ // declare %struct.__tgt_async_info @__tgt_target_data_begin_issue(i64, i32, + // i8**, i8**, i64*, i64*) + FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction( + *M, OMPRTL___tgt_target_data_begin_issue); + + // Change RuntimeCall callsite for its asynchronous version. NOTE(review): RuntimeCall is removed and deleted below and is not reassigned in this function -- confirm nothing reads it afterwards. + SmallVector Args; + Args.reserve(RuntimeCall->getNumArgOperands()); + for (auto &Arg : RuntimeCall->args()) + Args.push_back(Arg.get()); + + CallInst *IssueCallsite = CallInst::Create( + IssueDecl, ArrayRef(Args), "handle", RuntimeCall); + RuntimeCall->removeFromParent(); + RuntimeCall->deleteValue(); + Issue.insert(IssueCallsite); + + // Add "wait" runtime call declaration. + // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info) + FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction( + *M, OMPRTL___tgt_target_data_begin_wait); + + // Add "wait" call site. + const unsigned WaitNumParams = 2; + Value *WaitParams[] = { + IssueCallsite->getArgOperand(0), // device_id. + IssueCallsite // returned handle. + }; + CallInst::Create( + WaitDecl, ArrayRef(WaitParams, WaitNumParams), /*NameStr=*/"", + /*InsertBefore=*/(Instruction *)nullptr); + + // TODO: Move the "issue" after After and the "wait" before Before. Until then the "wait" call created above has no insertion point (InsertBefore is nullptr). + return true; +} + std::unique_ptr OffloadArray::initialize( AllocaInst &Array, Instruction &Before, InformationCache &InfoCache) { if (!Array.getAllocatedType()->isArrayTy()) { @@ -802,7 +859,8 @@ Changed |= runAttributor(); Changed |= deduplicateRuntimeCalls(); Changed |= deleteParallelRegions(); - Changed |= hideMemTransfersLatency(); + if (SplitMemoryTransfers) + Changed |= hideMemTransfersLatency(); return Changed; } @@ -945,11 +1003,9 @@ return false; } - if (canBeMovedUpwards(MT) || canBeMovedDownwards(MT)) { - // TODO: Split runtime call. - } - - return false; + auto *After = canBeMovedUpwards(MT); + auto *Before = canBeMovedDownwards(MT); + return (After || Before) && MT.split(After, Before); }; RFI.foreachUse(SplitDataTransfer);