Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -427,7 +427,7 @@
           VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
 __OMP_RTL(__tgt_target_data_begin_mapper_issue, false, Void, IdentPtr, Int64, Int32,
           VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, AsyncInfoPtr)
-__OMP_RTL(__tgt_target_data_begin_mapper_wait, false, Void, Int64, AsyncInfoPtr)
+__OMP_RTL(__tgt_target_data_begin_mapper_wait, false, Void, IdentPtr, Int64, AsyncInfoPtr)
 __OMP_RTL(__tgt_target_data_end_mapper, false, Void, IdentPtr, Int64, Int32,
           VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr)
 __OMP_RTL(__tgt_target_data_end_nowait_mapper, false, Void, IdentPtr, Int64, Int32,
Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
 #include "llvm/Analysis/MemoryLocation.h"
@@ -49,6 +50,7 @@
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
+#include "llvm/Transforms/Utils/CodeExtractor.h"
 
 #include
 
@@ -4182,6 +4184,348 @@
   }
 };
 
+bool splitMapperToIssueAndWait(CallInst *RuntimeCall,
+                               OMPInformationCache *InfoCache) {
+  auto &IRBuilder = InfoCache->OMPBuilder;
+  Function *F = RuntimeCall->getCaller();
+  Module *M = F->getParent();
+  Instruction *FirstInst = &(F->getEntryBlock().front());
+  AllocaInst *Handle = new AllocaInst(IRBuilder.AsyncInfo, F->getAddressSpace(),
+                                      "handle", FirstInst);
+
+  FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
+      *M, OMPRTL___tgt_target_data_begin_mapper_issue);
+
+  // Replace the call site with its asynchronous "issue" version, which takes
+  // the original arguments followed by the handle.
+  SmallVector Args;
+  for (auto &Arg : RuntimeCall->args())
+    Args.push_back(Arg.get());
+  Args.push_back(Handle);
+
+  CallInst::Create(IssueDecl, Args, /*NameStr=*/"", RuntimeCall);
+
+  FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
+      *M, OMPRTL___tgt_target_data_begin_mapper_wait);
+
+  Value *WaitParams[3] = {
+      RuntimeCall->getArgOperand(0),                            // loc
+      RuntimeCall->getArgOperand(OffloadArray::DeviceIDArgNum), // device_id.
+      Handle                                                    // handle to wait on.
+  };
+  CallInst::Create(WaitDecl, WaitParams, "", RuntimeCall);
+
+  RuntimeCall->eraseFromParent(); // issue + wait now stand in its place.
+
+  return true;
+}
+
+SmallVector>
+getUseTreeDuplicated(Instruction *I) {
+  SmallVector> UseTreeDuplicated;
+  UseTreeDuplicated.push_back({I, nullptr});
+
+  for (unsigned int i = 0; i < UseTreeDuplicated.size(); ++i)
+    for (User *U : dyn_cast(UseTreeDuplicated[i].first)->users()) {
+      if (dyn_cast(U))
+        continue;
+      assert((dyn_cast(U) || dyn_cast(U) ||
+              dyn_cast(U)) &&
+             "Problem in data mapping");
+      auto UInstr = cast(U);
+      if (isSafeToSpeculativelyExecute(UInstr)) // pair it with a clone
+        UseTreeDuplicated.push_back({UInstr, UInstr->clone()});
+      else
+        UseTreeDuplicated.push_back({UInstr, nullptr});
+    }
+
+  UseTreeDuplicated.erase(UseTreeDuplicated.begin()); // drop the root I itself.
+  return UseTreeDuplicated;
+}
+
+bool reorganizeDuplicatedInstructions(
+    SmallVector> UseTreeDuplicated,
+    Instruction *I) {
+  for (auto Utd : UseTreeDuplicated) {
+    // insert the duplicate (if any) where the original currently is
+    if (Utd.second)
+      Utd.second->insertBefore(Utd.first);
+    // move the original right before I
+    Utd.first->moveBefore(I);
+  }
+  return true;
+}
+
+Function *outlineMapperIssueRT(CallInst &RuntimeCall) {
+  BasicBlock *RuntimeCallBB = RuntimeCall.getParent();
+  CallInst *IssueRuntimeCall = &RuntimeCall;
+
+  Value *BasePtrsArg =
+      IssueRuntimeCall->getArgOperand(OffloadArray::BasePtrsArgNum);
+  // i8** %offload_ptrs.
+  Value *PtrsArg = IssueRuntimeCall->getArgOperand(OffloadArray::PtrsArgNum);
+  // i8** %offload_sizes.
+  Value *SizesArg = IssueRuntimeCall->getArgOperand(OffloadArray::SizesArgNum);
+
+  auto *BasePtrsArray = dyn_cast(getUnderlyingObject(BasePtrsArg));
+  auto *PtrsArray = dyn_cast(getUnderlyingObject(PtrsArg));
+  auto *SizesArray = dyn_cast(getUnderlyingObject(SizesArg));
+  if (!BasePtrsArray || !PtrsArray || !SizesArray)
+    return nullptr;
+
+  // Isolate the issue call in its own block ("mapper.bb") so it can be outlined.
+  BasicBlock *MapperBB =
+      RuntimeCallBB->splitBasicBlock(IssueRuntimeCall, "mapper.bb");
+  BasicBlock *NextBB =
+      MapperBB->splitBasicBlock(MapperBB->front().getNextNode(), "next.bb");
+
+  auto BasePtrsUses = getUseTreeDuplicated(BasePtrsArray);
+  auto PtrsUses = getUseTreeDuplicated(PtrsArray);
+  auto SizesArrayUses = getUseTreeDuplicated(SizesArray);
+
+  reorganizeDuplicatedInstructions(BasePtrsUses, IssueRuntimeCall);
+  reorganizeDuplicatedInstructions(PtrsUses, IssueRuntimeCall);
+  reorganizeDuplicatedInstructions(SizesArrayUses, IssueRuntimeCall);
+
+  for (auto Bpu : BasePtrsUses)
+    if (Bpu.second)
+      Bpu.first->replaceUsesOutsideBlock(Bpu.second, MapperBB);
+  for (auto Pu : PtrsUses)
+    if (Pu.second)
+      Pu.first->replaceUsesOutsideBlock(Pu.second, MapperBB);
+  for (auto Sau : SizesArrayUses)
+    if (Sau.second)
+      Sau.first->replaceUsesOutsideBlock(Sau.second, MapperBB);
+
+  SmallVector ExtractBB{MapperBB};
+  CodeExtractor CE(ExtractBB);
+  CodeExtractorAnalysisCache CEAC(*(MapperBB->getParent()));
+  Function *OutlinedFunc = CE.extractCodeRegion(CEAC);
+  if (OutlinedFunc) // nullptr when the region is not eligible for extraction.
+    OutlinedFunc->setName("mapper_issue_wrapper");
+  MergeBlockIntoPredecessor(NextBB);
+
+  return OutlinedFunc;
+}
+
+bool mapperFunctionAnnotation(Function *OutlinedMapperFunc) {
+  OutlinedMapperFunc->addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
+  int i = 0;
+  for (auto &A : OutlinedMapperFunc->args()) {
+    if (A.getType()->isPointerTy())
+      OutlinedMapperFunc->addParamAttr(i, Attribute::NoCapture);
+    i++;
+  }
+  return true;
+}
+
+CallInst *getRTFunctionCall(std::string RTFuncName, Function &F) {
+  // Returns the first direct call in F whose callee is named RTFuncName, or
+  // nullptr if there is none. Indirect calls have no callee and are skipped.
+  for (auto &I : instructions(F)) {
+    auto *CI = dyn_cast<CallInst>(&I);
+    if (!CI)
+      continue;
+    Function *Callee = CI->getCalledFunction();
+    if (Callee && Callee->getName() == RTFuncName)
+      return CI;
+  }
+  return nullptr;
+}
+
+// Returns true if CI can be moved across I, in either direction.
+bool canMoveThrough(CallInst *CI, Instruction *I, AliasAnalysis &AA) {
+  for (Value *Op : CI->operand_values()) // CI must stay below its operands.
+    if (Op == I)
+      return false;
+
+  // TODO: check again for mayReadFromMemory
+  if (!(I->mayHaveSideEffects()) &&
+      !(dyn_cast(I) && I->mayReadFromMemory()))
+    return true;
+
+  return isNoModRef(AA.getModRefInfo(I, CI));
+}
+
+bool canMoveThroughBlock(CallInst *CI, BasicBlock *B, AliasAnalysis &AA) {
+  for (auto &I : *B)
+    if (!canMoveThrough(CI, &I, AA))
+      return false;
+  return true;
+}
+
+// Returns true if B can be reached from A by following at least one CFG edge.
+bool isAAccessibleFromB(BasicBlock *A, BasicBlock *B) {
+  SetVector<BasicBlock *> Worklist; // no duplicates, so cyclic CFGs terminate.
+  Worklist.insert(A);
+  for (unsigned int i = 0; i < Worklist.size(); ++i)
+    for (BasicBlock *V : successors(Worklist[i])) {
+      if (V == B)
+        return true;
+      Worklist.insert(V);
+    }
+  return false;
+}
+// Returns true if B is an element of VecBB.
+bool bbIsInVec(SmallVector VecBB, BasicBlock *B) {
+  for (auto Vb : VecBB)
+    if (Vb == B)
+      return true;
+  return false;
+}
+
+// This function gets the current BB of 'issue' and returns the
+// next BB the 'issue' function can safely move to.
+BasicBlock *findNextBBToCheckForMoving(CallInst *CI, BasicBlock *B,
+                                       AliasAnalysis &AA, DominatorTree &DT) {
+  auto *Node = DT.getNode(B);
+  if (!Node || !Node->getIDom()) // unreachable block, or no dominator above B.
+    return B;
+  BasicBlock *DomBlock = Node->getIDom()->getBlock();
+  SmallVector SuccessorBB;
+  SuccessorBB.push_back(DomBlock);
+
+  for (unsigned int i = 0; i < SuccessorBB.size(); ++i) {
+    for (BasicBlock *S : successors(SuccessorBB[i])) {
+      // Skip the destination itself and any block that is already queued
+      // (which happens when the CFG has loops).
+      if (S == B || bbIsInVec(SuccessorBB, S))
+        continue;
+      // If the instruction cannot move through a block on the way,
+      // it cannot be moved to the dominator.
+      // However, if B is not reachable from that block, it does not matter,
+      // and we keep looking.
+      if (!canMoveThroughBlock(CI, S, AA) && isAAccessibleFromB(S, B))
+        return B;
+      SuccessorBB.push_back(S);
+    }
+  }
+  return DomBlock;
+}
+
+bool moveIssueRTCInOrigBB(CallInst *IssueWrapperCall, AliasAnalysis &AA) {
+  Instruction *IssueMovePoint = nullptr;
+  Instruction *I = IssueWrapperCall;
+  while ((I = I->getPrevNonDebugInstruction()))
+    if (!canMoveThrough(IssueWrapperCall, I, AA)) {
+      IssueMovePoint = I;
+      break;
+    }
+  if (!I)
+    return false;
+  // There is an instruction in the current BB that the issue
+  // cannot move through. In this case, we don't need to check
+  // other blocks.
+  IssueWrapperCall->moveAfter(IssueMovePoint);
+  return true;
+}
+
+bool moveWaitRTCInOrigBB(CallInst *IssueWrapperCall, CallInst *RTCallWait,
+                         AliasAnalysis &AA) {
+  Instruction *I = RTCallWait;
+  Instruction *WaitMovePoint = nullptr;
+
+  while ((I = I->getNextNonDebugInstruction()))
+    if (!canMoveThrough(IssueWrapperCall, I, AA)) {
+      WaitMovePoint = I;
+      break;
+    }
+  if (!I) { // nothing blocks the wait: sink it to the end of its block.
+    RTCallWait->moveBefore(RTCallWait->getParent()->getTerminator());
+    return true;
+  }
+
+  RTCallWait->moveBefore(WaitMovePoint);
+
+  return true;
+}
+
+bool moveIssueRTCInBB(CallInst *IssueWrapperCall, BasicBlock *CurrentBB,
+                      AliasAnalysis &AA) {
+  Instruction *IssueMovePoint = nullptr;
+  Instruction *I = CurrentBB->getTerminator();
+
+  while ((I = I->getPrevNonDebugInstruction()))
+    if (!canMoveThrough(IssueWrapperCall, I, AA)) {
+      IssueMovePoint = I;
+      break;
+    }
+
+  // insert issue at the very beginning of the BB, after any PHI nodes.
+  if (!I) {
+    IssueWrapperCall->moveBefore(&*CurrentBB->getFirstInsertionPt());
+    return true;
+  }
+  IssueWrapperCall->moveAfter(IssueMovePoint);
+  return true;
+}
+
+bool hideMemTransfersLatency(OMPInformationCache *InfoCache, Function &F,
+                             FunctionAnalysisManager &FAM) {
+  auto &RFI = InfoCache->RFIs[OMPRTL___tgt_target_data_begin_mapper];
+  bool Changed = false;
+  auto AsyncMemTransfers = [&](Use &U, Function &Decl) {
+    auto *RTCall = OpenMPOpt::getCallIfRegularCall(U, &RFI);
+    if (!RTCall)
+      return false; // this use is untouched.
+    Instruction *NextInst = RTCall->getNextNode();
+    if (!splitMapperToIssueAndWait(RTCall, InfoCache))
+      return false;
+    Changed = true; // the call site has been rewritten from here on.
+
+    // The split inserts the issue and then the wait right before the original
+    // call (now erased), so they are the two instructions preceding NextInst.
+    auto *RTCallWait = cast<CallInst>(NextInst->getPrevNode());
+    auto *RTCallIssue = cast<CallInst>(RTCallWait->getPrevNode());
+
+    Function *IssueWrapper = outlineMapperIssueRT(*RTCallIssue);
+    if (!IssueWrapper) // cannot outline the function for some reasons
+      return Changed;
+
+    mapperFunctionAnnotation(IssueWrapper);
+    CallInst *IssueWrapperCall =
+        getRTFunctionCall(IssueWrapper->getName().str(), F);
+
+    MergeBlockIntoPredecessor(IssueWrapperCall->getParent());
+    BasicBlock *IssueBB = IssueWrapperCall->getParent();
+
+    AliasAnalysis &AA = FAM.getResult(F);
+    moveWaitRTCInOrigBB(IssueWrapperCall, RTCallWait, AA);
+    bool issueMovedInOrigBB = moveIssueRTCInOrigBB(IssueWrapperCall, AA);
+
+    // cannot move beyond its original bb
+    if (issueMovedInOrigBB) {
+      Changed = true;
+      return Changed;
+    }
+    BasicBlock *CurrentBB = IssueBB;
+    BasicBlock *NextBB;
+    DominatorTree DT = DominatorTree(F);
+
+    while ((NextBB = findNextBBToCheckForMoving(IssueWrapperCall, CurrentBB, AA,
+                                                DT)) != CurrentBB) {
+      CurrentBB = NextBB;
+      // A blocker inside this dominator ends the hoist here.
+      if (!canMoveThroughBlock(IssueWrapperCall, CurrentBB, AA))
+        break;
+    }
+    // It is when we have tried other BB and now we know we really cannot move
+    // it beyond its original BB, so the best thing to do is to move
+    // the issue to the original BB front.
+    if (CurrentBB == IssueBB) {
+      IssueWrapperCall->moveBefore(&*IssueBB->getFirstInsertionPt());
+      Changed = true;
+      return Changed;
+    }
+    if (moveIssueRTCInBB(IssueWrapperCall, CurrentBB, AA))
+      Changed = true;
+
+    return Changed;
+  };
+
+  RFI.foreachUse(AsyncMemTransfers, &F);
+  return Changed;
+}
+
 /// The call site kernel info abstract attribute, basically, what can we say
 /// about a call site with regards to the KernelInfoState. For now this simply
 /// forwards the information from the callee.
@@ -5026,6 +5370,11 @@
   if (PrintModuleAfterOptimizations)
     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
 
+  if (HideMemoryTransferLatency) {
+    for (Function &F : M)
+      Changed |= hideMemTransfersLatency(&InfoCache, F, FAM);
+    return PreservedAnalyses::none();
+  }
   if (Changed)
     return PreservedAnalyses::none();
 
Index: openmp/libomptarget/include/omptarget.h
===================================================================
--- openmp/libomptarget/include/omptarget.h
+++ openmp/libomptarget/include/omptarget.h
@@ -187,10 +187,15 @@
 
   __tgt_async_info AsyncInfo;
   DeviceTy &Device;
+  bool SyncFlag; // when false, the destructor does not synchronize.
 
 public:
-  AsyncInfoTy(DeviceTy &Device) : Device(Device) {}
-  ~AsyncInfoTy() { synchronize(); }
+  AsyncInfoTy(DeviceTy &Device, bool Sf = true) : Device(Device), SyncFlag(Sf) {}
+
+  ~AsyncInfoTy() {
+    if (SyncFlag)
+      synchronize();
+  }
 
   /// Implicit conversion to the __tgt_async_info which is used in the
   /// plugin interface.
@@ -200,7 +205,6 @@
   ///
   /// \returns OFFLOAD_FAIL or OFFLOAD_SUCCESS appropriately.
   int synchronize();
-
   /// Return a void* reference with a lifetime that is at least as long as this
   /// AsyncInfoTy object. The location can be used as intermediate buffer.
   void *&getVoidPtrLocation();
@@ -278,6 +282,14 @@
                                     void **Args, int64_t *ArgSizes,
                                     int64_t *ArgTypes, map_var_info_t *ArgNames,
                                     void **ArgMappers);
+
+void __tgt_target_data_begin_mapper_issue(
+    ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
+    void **Args, int64_t *ArgSizes, int64_t *ArgTypes, map_var_info_t *ArgNames,
+    void **ArgMappers, __tgt_async_info *Handle);
+
+void __tgt_target_data_begin_mapper_wait(ident_t *Loc, int64_t DeviceId,
+                                         __tgt_async_info *Handle);
 void __tgt_target_data_begin_nowait_mapper(
     ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
     void **Args, int64_t *ArgSizes, int64_t *ArgTypes, map_var_info_t *ArgNames,
Index: openmp/libomptarget/src/exports
===================================================================
--- openmp/libomptarget/src/exports
+++ openmp/libomptarget/src/exports
@@ -15,6 +15,8 @@
     __tgt_target_nowait;
     __tgt_target_teams_nowait;
     __tgt_target_data_begin_mapper;
+    __tgt_target_data_begin_mapper_issue;
+    __tgt_target_data_begin_mapper_wait;
     __tgt_target_data_end_mapper;
     __tgt_target_data_update_mapper;
     __tgt_target_mapper;
Index: openmp/libomptarget/src/interface.cpp
===================================================================
--- openmp/libomptarget/src/interface.cpp
+++ openmp/libomptarget/src/interface.cpp
@@ -70,6 +70,15 @@
                                            int64_t *ArgTypes,
                                            map_var_info_t *ArgNames,
                                            void **ArgMappers) {
+  __tgt_target_data_begin_mapper_issue(Loc, DeviceId, ArgNum, ArgsBase, Args,
+                                       ArgSizes, ArgTypes, ArgNames, ArgMappers,
+                                       nullptr);
+}
+
+EXTERN void __tgt_target_data_begin_mapper_issue(
+    ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
+    void **Args, int64_t *ArgSizes, int64_t *ArgTypes, map_var_info_t *ArgNames,
+    void **ArgMappers, __tgt_async_info *Handle) {
   TIMESCOPE_WITH_IDENT(Loc);
   DP("Entering data begin region for device %" PRId64 " with %d mappings\n",
      DeviceId, ArgNum);
@@ -84,22 +93,45 @@
     printKernelArguments(Loc, DeviceId, ArgNum, ArgSizes, ArgTypes, ArgNames,
                          "Entering OpenMP data region");
 #ifdef OMPTARGET_DEBUG
   for (int I = 0; I < ArgNum; ++I) {
     DP("Entry %2d: Base=" DPxMOD ", Begin=" DPxMOD ", Size=%" PRId64
        ", Type=0x%" PRIx64 ", Name=%s\n",
        I, DPxPTR(ArgsBase[I]), DPxPTR(Args[I]), ArgSizes[I], ArgTypes[I],
        (ArgNames) ? getNameFromMapping(ArgNames[I]).c_str() : "unknown");
   }
 #endif
 
-  AsyncInfoTy AsyncInfo(Device);
+  AsyncInfoTy AsyncInfo(Device, /*Sf=*/!Handle);
   int Rc = targetDataBegin(Loc, Device, ArgNum, ArgsBase, Args, ArgSizes,
                            ArgTypes, ArgNames, ArgMappers, AsyncInfo);
-  if (Rc == OFFLOAD_SUCCESS)
+  if (Handle) // Always publish the queue so the caller's handle is defined.
+    Handle->Queue = static_cast<__tgt_async_info *>(AsyncInfo)->Queue;
+  else if (Rc == OFFLOAD_SUCCESS)
     Rc = AsyncInfo.synchronize();
+
   handleTargetOutcome(Rc == OFFLOAD_SUCCESS, Loc);
 }
 
+EXTERN void __tgt_target_data_begin_mapper_wait(ident_t *Loc, int64_t DeviceId,
+                                                __tgt_async_info *Handle) {
+  TIMESCOPE_WITH_IDENT(Loc);
+  assert(Handle && "Incomplete data mapping");
+  DP("Waiting for data begin transfers on device %" PRId64 "\n", DeviceId);
+  if (!Handle->Queue) // No queue was handed over: nothing to wait on.
+    return;
+
+  // TODO: this call is redundant here; it is only used to get DeviceId.
+  if (checkDeviceAndCtors(DeviceId, Loc)) {
+    DP("Not offloading to device %" PRId64 "\n", DeviceId);
+    return;
+  }
+  DeviceTy &Device = *PM->Devices[DeviceId];
+  int Rc = OFFLOAD_SUCCESS;
+  if (Device.RTL->synchronize) // The plugin takes its own device index.
+    Rc = Device.RTL->synchronize(Device.RTLDeviceID, Handle);
+  handleTargetOutcome(Rc == OFFLOAD_SUCCESS, Loc);
+}
+
 EXTERN void __tgt_target_data_begin_nowait_mapper(
     ident_t *Loc, int64_t DeviceId, int32_t ArgNum, void **ArgsBase,
     void **Args, int64_t *ArgSizes, int64_t *ArgTypes, map_var_info_t *ArgNames,