Index: llvm/include/llvm/Transforms/IPO/OpenMPOpt.h =================================================================== --- llvm/include/llvm/Transforms/IPO/OpenMPOpt.h +++ llvm/include/llvm/Transforms/IPO/OpenMPOpt.h @@ -175,20 +175,27 @@ bool isFilled(); }; - CallBase *RuntimeCall; /// Call that involves a memotry transfer. + CallInst *RuntimeCall; /// Call that involves a memory transfer. InformationCache &InfoCache; /// These help mapping the values in offload_baseptrs, offload_ptrs, and /// offload_sizes, respectively. + const unsigned BasePtrsArgNum = 2; std::unique_ptr BasePtrs = nullptr; + const unsigned PtrsArgNum = 3; std::unique_ptr Ptrs = nullptr; + const unsigned SizesArgNum = 4; std::unique_ptr Sizes = nullptr; /// Set of instructions that compose the argument setup for the call /// RuntimeCall. SetVector Issue; - MemoryTransfer(CallBase *RuntimeCall, InformationCache &InfoCache) : + /// Runtime call that will wait on the handle returned by the runtime call + /// in Issue. + CallInst *Wait = nullptr; + + MemoryTransfer(CallInst *RuntimeCall, InformationCache &InfoCache) : RuntimeCall{RuntimeCall}, InfoCache{InfoCache} {} @@ -207,6 +214,11 @@ /// offload arrays. bool mayBeModifiedBy(Instruction *I); + /// Splits this object into its "issue" and "wait" corresponding runtime + /// calls. The "issue" is moved after \p After and the "wait" is moved + /// before \p Before. + bool split(Instruction *After, Instruction *Before); + private: /// Gets the setup instructions for each of the values in \p OA. These /// instructions are stored into Issue. @@ -218,6 +230,14 @@ /// Returns true if \p I may modify one of the values in \p Values. bool mayModify(Instruction *I, SmallVectorImpl &Values); + + /// Creates the StructureType %struct.tgt_async_info = type { i8* } + /// or returns a pointer to it if already exists. + Type *getOrCreateHandleType(); + + /// Removes from the function all the instructions in Issue and inserts + /// them after \p After. 
+ void moveIssue(Instruction *After); }; /// The slice of the module we are allowed to look at. @@ -301,6 +321,10 @@ /// moved. Returns nullptr if the movement is not possible, or not worth it. Instruction *canBeMovedUpwards(MemoryTransfer &MT); + /// Returns a pointer to the instruction where the "wait" of \p MT can be + /// moved. Returns nullptr if the movement is not possible, or not worth it. + Instruction *canBeMovedDownwards(MemoryTransfer &MT); + static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, bool GlobalOnly, bool &SingleChoice); Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp =================================================================== --- llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -253,13 +253,10 @@ // arrays, offload_baseptrs, offload_ptrs, offload_sizes. // Therefore: // i8** %offload_baseptrs. - const unsigned BasePtrsArgNum = 2; Use *BasePtrsArg = RuntimeCall->arg_begin() + BasePtrsArgNum; // i8** %offload_ptrs. - const unsigned PtrsArgNum = 3; Use *PtrsArg = RuntimeCall->arg_begin() + PtrsArgNum; // i8** %offload_sizes. 
- const unsigned SizesArgNum = 4; Use *SizesArg = RuntimeCall->arg_begin() + SizesArgNum; const DataLayout &DL = InfoCache.getDL(); @@ -337,6 +334,10 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *BasePtrsGEP = + cast(RuntimeCall->getArgOperand(BasePtrsArgNum)); + if (!Issue.count(BasePtrsGEP)) + Issue.insert(BasePtrsGEP); Success = getSetupInstructions(Ptrs); if (!Success) { @@ -346,6 +347,10 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *PtrsGEP = + cast(RuntimeCall->getArgOperand(PtrsArgNum)); + if (!Issue.count(PtrsGEP)) + Issue.insert(PtrsGEP); if (Sizes) { Success = getSetupInstructions(Sizes); @@ -356,6 +361,10 @@ << RuntimeCall->getCaller()->getName() << "\n"); return false; } + auto *SizesGEP = + cast(RuntimeCall->getArgOperand(SizesArgNum)); + if (!Issue.count(SizesGEP)) + Issue.insert(SizesGEP); } return true; @@ -495,6 +504,80 @@ return true; } +bool MemoryTransfer::split(Instruction *After, Instruction *Before) { + assert((After || Before) && + "Must have a place to move the split runtime call"); + + auto *HandleType = getOrCreateHandleType(); + if (!HandleType) + return false; + + auto *M = RuntimeCall->getModule(); + // Add "issue" runtime call declaration. + // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32, + // i8**, i8**, i64*, i64*) + auto IssueParams = RuntimeCall->getFunctionType()->params(); + FunctionCallee IssueDecl = M->getOrInsertFunction( + "__tgt_target_data_begin_issue", + FunctionType::get(HandleType, IssueParams, false)); + + // Replace RuntimeCall by a call to its asynchronous version. A new call is + // required: swapping the callee in place would keep the call's old (void) + // type, so it could not be passed as the handle to the "wait" call. 
+ SmallVector<Value *, 8> IssueArgs(RuntimeCall->arg_begin(), + RuntimeCall->arg_end()); + CallInst *IssueCall = + CallInst::Create(IssueDecl, IssueArgs, "handle", RuntimeCall); + IssueCall->setDebugLoc(RuntimeCall->getDebugLoc()); + RuntimeCall->eraseFromParent(); + RuntimeCall = IssueCall; + Issue.insert(RuntimeCall); + + // Add "wait" runtime call declaration.
+ // declare void @__tgt_target_data_begin_wait(i64, %struct.tgt_async_info) + // Parameter and argument lists are plain local arrays viewed via ArrayRef. + Type *WaitDeclParams[] = {Type::getInt64Ty(M->getContext()), HandleType}; + FunctionCallee WaitDecl = M->getOrInsertFunction( + "__tgt_target_data_begin_wait", + FunctionType::get(Type::getVoidTy(M->getContext()), WaitDeclParams, + false)); + + // Add "wait" call site. + Value *WaitArgs[] = {RuntimeCall->getArgOperand(0), // device_id. + RuntimeCall}; // handle returned. + Wait = CallInst::Create(WaitDecl, WaitArgs); + + // Move wait. + if (!Before) + Wait->insertAfter(RuntimeCall); + else + Wait->insertBefore(Before); + + if (After) + moveIssue(After); + + return true; +} + +Type *MemoryTransfer::getOrCreateHandleType() { + auto *M = RuntimeCall->getModule(); + // If already exists do not create it. + for (auto *ST : M->getIdentifiedStructTypes()) + if (ST->getName() == "struct.tgt_async_info") + return ST; + + return StructType::create("struct.tgt_async_info", + Type::getInt8PtrTy(M->getContext())); +} + +void MemoryTransfer::moveIssue(Instruction *After) { + for (auto *I : Issue) { + I->removeFromParent(); + I->insertAfter(After); + After = I; + } +} + std::unique_ptr OffloadArray::initialize( AllocaInst &Array, Instruction &Before, InformationCache &InfoCache) { if (!Array.getAllocatedType()->isArrayTy()) { @@ -945,10 +1028,9 @@ return false; } - if (auto *I = canBeMovedUpwards(MT)) { - // TODO: Split call and move "issue" below I. 
- } - return false; + auto *After = canBeMovedUpwards(MT); + auto *Before = canBeMovedDownwards(MT); + return (After || Before) && MT.split(After, Before); }; RFI.foreachUse(SplitDataTransfer); @@ -958,7 +1040,7 @@ Instruction *OpenMPOpt::canBeMovedUpwards(MemoryTransfer &MT) { assert(MT.Issue.size() > 0 && "There's not set of instructions to be moved!"); - CallBase *RC = MT.RuntimeCall; + CallInst *RC = MT.RuntimeCall; auto *MSSAResult = OMPInfoCache.getAnalysisResultForFunction( *RC->getCaller()); @@ -987,6 +1069,40 @@ return nullptr; } +Instruction *OpenMPOpt::canBeMovedDownwards(MemoryTransfer &MT) { + assert(MT.Issue.size() > 0 && "There's no set of instructions to be moved!"); + + // FIXME: This traverses only the BasicBlock where MT is. Make it traverse + // the CFG. + // NOTE(review): instructions between the "issue" and the returned position + // are not checked for writes to the transferred host memory; confirm. + GlobalValue *TgtTargetDecl = M.getNamedValue("__tgt_target"); + GlobalValue *TgtTargetTeamsDecl = M.getNamedValue("__tgt_target_teams"); + GlobalValue *TgtTargetDataEndDecl = M.getNamedValue("__tgt_target_data_end"); + CallInst *RC = MT.RuntimeCall; + + // The "wait" goes before the first runtime call that may need the + // transferred data (indirect calls are conservatively treated as such) + // or, if there is none, before the terminator of the block. + Instruction *Pos = RC->getParent()->getTerminator(); + for (auto *I = RC->getNextNode(); I; I = I->getNextNode()) { + auto *C = dyn_cast<CallInst>(I); + if (!C) + continue; + auto *Callee = C->getCalledFunction(); + if (!Callee || Callee == TgtTargetDecl || Callee == TgtTargetTeamsDecl || + Callee == TgtTargetDataEndDecl) { + Pos = I; + break; + } + } + + // Not worth it: the "wait" would stay right after the "issue". + if (Pos == RC->getNextNode()) + return nullptr; + return Pos; +} + Value *OpenMPOpt::combinedIdentStruct(Value *CurrentIdent, Value *NextIdent, bool GlobalOnly, bool &SingleChoice) { if (CurrentIdent == NextIdent)