Index: llvm/lib/Transforms/IPO/OpenMPOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Transforms/IPO/Attributor.h"
 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/MemorySSA.h"
 
 using namespace llvm;
 using namespace omp;
@@ -497,6 +498,19 @@
     return true;
   }
 
+  /// Returns true if RuntimeCall is set.
+  bool isInitialized() const { return RuntimeCall != nullptr; }
+
+  /// Moves every instruction in Issue (the setup instructions and the
+  /// "issue" runtime call) right after \p After, in the iteration order of
+  /// Issue. \p After must not be the last instruction of its block.
+  void moveIssue(Instruction &After) {
+    Instruction *Before = After.getNextNode();
+    assert(Before && "Cannot move the issue after a block terminator!");
+    for (auto *I : Issue)
+      I->moveBefore(Before);
+  }
+
   MemoryTransfer() = default;
   MemoryTransfer(const MemoryTransfer &) = delete;
   MemoryTransfer &operator=(const MemoryTransfer &) = delete;
@@ -835,11 +849,15 @@
       return false;
 
       LLVM_DEBUG(dumpMemoryTransferSetupInstructions(MT));
-      // TODO: Check if can be moved upwards.
       bool WasSplit = false;
-      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
-      if (WaitMovementPoint)
-        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
+      // NOTE(review): splitting rewrites the IR of this function, while
+      // canBeMovedUpwards queries its MemorySSA. Confirm that the analysis
+      // result is not reused stale for later runtime calls in the function.
+      Instruction *IssueNewLocation = canBeMovedUpwards(MT);
+      Instruction *WaitNewLocation = canBeMovedDownwards(*RTCall);
+      if (IssueNewLocation || WaitNewLocation)
+        WasSplit =
+            splitTargetDataBeginRTC(MT, IssueNewLocation, WaitNewLocation);
 
       Changed |= WasSplit;
       return WasSplit;
@@ -948,6 +966,122 @@
     LLVM_DEBUG(dbgs() << Printer.str());
   }
 
+  /// Returns the instruction after which the "issue" counterpart of \p MT
+  /// (its setup instructions and runtime call) can be moved. Returns nullptr
+  /// if the movement is not possible, or not worth it.
+  Instruction *canBeMovedUpwards(MemoryTransfer &MT) {
+    assert(MT.isInitialized() && "Initialize MemoryTransfer first!");
+
+    CallInst *RC = MT.RuntimeCall;
+    auto *MSSAResult =
+        OMPInfoCache.getAnalysisResultForFunction<MemorySSAAnalysis>(
+            *RC->getCaller());
+    if (!MSSAResult)
+      return nullptr;
+
+    auto &MSSA = MSSAResult->getMSSA();
+    MemoryUseOrDef *RCAccess = MSSA.getMemoryAccess(RC);
+    if (!RCAccess)
+      return nullptr;
+
+    // Visit the preceding memory definitions one at a time. Jumping with the
+    // clobber walker would skip definitions that do not clobber the previous
+    // one, even though they may modify the transferred memory.
+    MemoryAccess *MemAccess = RCAccess->getDefiningAccess();
+    while (!MSSA.isLiveOnEntryDef(MemAccess)) {
+      // A MemoryPhi merges several paths: do not move the issue across it.
+      auto *MemDef = dyn_cast<MemoryDef>(MemAccess);
+      if (!MemDef)
+        return nullptr;
+
+      // Only move the issue within the basic block of the runtime call.
+      Instruction *MemInst = MemDef->getMemoryInst();
+      if (MemInst->getParent() != RC->getParent())
+        return nullptr;
+
+      if (mayBeModifiedBy(MT, *MemInst)) {
+        // If the instruction right after MemInst is already part of the setup
+        // instructions, that is, MT.Issue, the movement is not worth it.
+        if (MT.Issue.count(MemInst->getNextNode()))
+          return nullptr;
+        if (!canIssueBeMovedAfter(MT, *MemInst))
+          return nullptr;
+        return MemInst;
+      }
+
+      MemAccess = MemDef->getDefiningAccess();
+    }
+
+    return nullptr;
+  }
+
+  /// Returns false if an instruction located strictly between \p After and
+  /// the runtime call of \p MT, and not part of MT.Issue, is used by an
+  /// instruction in MT.Issue or by the runtime call: moving the issue right
+  /// after \p After would then place those uses before their definition.
+  /// \p After is expected to precede the runtime call in the same block.
+  bool canIssueBeMovedAfter(MemoryTransfer &MT, Instruction &After) {
+    CallInst *RC = MT.RuntimeCall;
+    for (Instruction *I = After.getNextNode(); I && I != RC;
+         I = I->getNextNode()) {
+      if (MT.Issue.count(I))
+        continue;
+      for (User *U : I->users()) {
+        auto *UserInst = dyn_cast<Instruction>(U);
+        if (UserInst && (UserInst == RC || MT.Issue.count(UserInst)))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  /// Returns true if \p I may modify the memory pointed to by the values
+  /// stored in the offload arrays of \p MT. Instructions that are part of
+  /// MT.Issue are never reported as modifying it.
+  bool mayBeModifiedBy(MemoryTransfer &MT, Instruction &I) {
+    if (MT.Issue.count(&I))
+      return false;
+
+    if (mayModify(I, MT.OffloadArrays[0].StoredValues))
+      return true;
+    if (mayModify(I, MT.OffloadArrays[1].StoredValues))
+      return true;
+    if (MT.OffloadArrays[2].isInitialized()) {
+      if (mayModify(I, MT.OffloadArrays[2].StoredValues))
+        return true;
+    }
+
+    return false;
+  }
+
+  /// Returns true if \p I may modify the memory pointed to by any pointer in
+  /// \p Values; values that are not pointers are skipped. Conservatively
+  /// returns true if \p I is neither a store nor a call, or if alias analysis
+  /// results are not available.
+  bool mayModify(Instruction &I, const SmallVectorImpl<Value *> &Values) {
+    if (!isa<StoreInst>(&I) && !isa<CallInst>(&I))
+      return true;
+
+    auto *AAResults = OMPInfoCache.getAnalysisResultForFunction<AAManager>(
+        *I.getFunction());
+    if (!AAResults)
+      return true;
+
+    for (auto *V : Values) {
+      if (!V->getType()->isPointerTy())
+        continue;
+      // The number of bytes transferred through V is not known here, so
+      // query a location of unknown size rather than a fixed-size one.
+      // TODO: Check that calls which cannot access V (e.g. rand() with a
+      // noalias V) are reported as not modifying it.
+      MemoryLocation Loc(V, LocationSize::unknown());
+      if (isModSet(AAResults->getModRefInfo(&I, Loc)))
+        return true;
+    }
+
+    return false;
+  }
+
   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
   /// moved. Returns nullptr if the movement is not possible, or not worth it.
   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
@@ -977,10 +1111,18 @@
     return RuntimeCall.getParent()->getTerminator();
   }
 
-  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
-  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
-                               Instruction &WaitMovementPoint) {
+  /// Splits \p MT into its "issue" and "wait" counterparts.
+  ///
+  /// \param After  If not null, the setup instructions and the "issue" call
+  ///               are moved right after it.
+  /// \param Before If not null, the "wait" call is inserted before it;
+  ///               otherwise it is inserted right after the "issue" call.
+  bool splitTargetDataBeginRTC(MemoryTransfer &MT, Instruction *After,
+                               Instruction *Before) {
+    assert(MT.isInitialized() && "Must initialize MemoryTransfer first!");
+
     auto &IRBuilder = OMPInfoCache.OMPBuilder;
+    CallInst &RuntimeCall = *MT.RuntimeCall;
     // Add "issue" runtime call declaration:
     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
     //                                 i8**, i8**, i64*, i64*)
@@ -995,6 +1137,9 @@
     CallInst *IssueCallsite =
         CallInst::Create(IssueDecl, Args, "handle", &RuntimeCall);
     RuntimeCall.eraseFromParent();
+    // MT.RuntimeCall still points to the erased call; do not use it below.
+    // Record the "issue" call so that moveIssue moves it with the setup.
+    MT.Issue.insert(IssueCallsite);
 
     // Add "wait" runtime call declaration:
     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
@@ -1006,7 +1151,17 @@
         IssueCallsite->getArgOperand(0), // device_id.
         IssueCallsite                    // returned handle.
     };
-    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
+
+    // By default, place the "wait" right after the "issue" call, so that the
+    // handle it uses is defined before it.
+    if (!Before)
+      Before = IssueCallsite->getNextNode();
+    CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", Before);
+
+    // Hoist the setup instructions and the "issue" call; the "wait" stays
+    // where it was inserted above.
+    if (After)
+      MT.moveIssue(*After);
 
     return true;
   }