diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -3508,6 +3508,10 @@
   /// Returns true if HeapToStack conversion is assumed to be possible.
   virtual bool isAssumedHeapToStack(CallBase &CB) const = 0;
 
+  /// Returns true if HeapToStack conversion is assumed and the CB is a
+  /// callsite to a free operation to be removed.
+  virtual bool isAssumedHeapToStackRemovedFree(CallBase &CB) const = 0;
+
   /// Create an abstract attribute view for the position \p IRP.
   static AAHeapToStack &createForPosition(const IRPosition &IRP, Attributor &A);
 
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -5656,6 +5656,23 @@
     return false;
   }
 
+  /// See AAHeapToStack::isAssumedHeapToStackRemovedFree(...).
+  bool isAssumedHeapToStackRemovedFree(CallBase &CB) const override {
+    if (!isValidState())
+      return false;
+
+    for (auto &It : AllocationInfos) {
+      AllocationInfo &AI = *It.second;
+      if (AI.Status == AllocationInfo::INVALID)
+        continue;
+
+      if (AI.PotentialFreeCalls.count(&CB))
+        return true;
+    }
+
+    return false;
+  }
+
   ChangeStatus manifest(Attributor &A) override {
     assert(getState().isValidState() &&
            "Attempted to manifest an invalid state!");
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -2538,6 +2538,13 @@
   static AAHeapToShared &createForPosition(const IRPosition &IRP,
                                            Attributor &A);
 
+  /// Returns true if HeapToShared conversion is assumed to be possible.
+  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
+
+  /// Returns true if HeapToShared conversion is assumed and the CB is a
+  /// callsite to a free operation to be removed.
+  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
+
   /// See AbstractAttribute::getName().
   const std::string getName() const override { return "AAHeapToShared"; }
 
@@ -2566,6 +2573,29 @@
   /// See AbstractAttribute::trackStatistics().
   void trackStatistics() const override {}
 
+  /// Recompute PotentialRemovedFreeCalls: for every call in MallocCalls that
+  /// has exactly one __kmpc_free_shared user, record that free call.
+  void findPotentialRemovedFreeCalls(Attributor &A) {
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
+
+    PotentialRemovedFreeCalls.clear();
+    // Update free call users of found malloc calls.
+    for (CallBase *CB : MallocCalls) {
+      SmallVector<CallBase *, 4> FreeCalls;
+      for (auto *U : CB->users()) {
+        CallBase *C = dyn_cast<CallBase>(U);
+        if (C && C->getCalledFunction() == FreeRFI.Declaration)
+          FreeCalls.push_back(C);
+      }
+      // TODO: Do we need to assume there is one, unique free user?
+      if (FreeCalls.size() != 1)
+        continue;
+
+      PotentialRemovedFreeCalls.insert(FreeCalls.front());
+    }
+  }
+
   void initialize(Attributor &A) override {
     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
@@ -2573,6 +2603,18 @@
     for (User *U : RFI.Declaration->users())
       if (CallBase *CB = dyn_cast<CallBase>(U))
         MallocCalls.insert(CB);
+
+    findPotentialRemovedFreeCalls(A);
+  }
+
+  /// See AAHeapToShared::isAssumedHeapToShared(...).
+  bool isAssumedHeapToShared(CallBase &CB) const override {
+    return isValidState() && MallocCalls.count(&CB);
+  }
+
+  /// See AAHeapToShared::isAssumedHeapToSharedRemovedFree(...).
+  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
+    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
   }
 
   ChangeStatus manifest(Attributor &A) override {
@@ -2661,6 +2703,8 @@
           MallocCalls.erase(CB);
     }
 
+    findPotentialRemovedFreeCalls(A);
+
     if (NumMallocCalls != MallocCalls.size())
       return ChangeStatus::CHANGED;
 
@@ -2669,6 +2713,8 @@
 
   /// Collection of all malloc calls in a function.
   SmallPtrSet<CallBase *, 4> MallocCalls;
+  /// Collection of potentially removed free calls in a function.
+  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
 };
 
 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
@@ -3378,6 +3424,14 @@
       SPMDCompatibilityTracker.insert(&CB);
       ReachedUnknownParallelRegions.insert(&CB);
       break;
+    case OMPRTL___kmpc_alloc_shared:
+    case OMPRTL___kmpc_free_shared:
+      A.getAAFor<AAHeapToStack>(*this, IRPosition::function(*CB.getCaller()),
+                                DepClassTy::OPTIONAL);
+      A.getAAFor<AAHeapToShared>(*this, IRPosition::function(*CB.getCaller()),
+                                 DepClassTy::OPTIONAL);
+      // Return without setting a fixpoint, to be resolved in updateImpl.
+      return;
     default:
       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
       // generally.
@@ -3396,12 +3450,58 @@
     // sense to specialize attributes for call sites arguments instead of
     // redirecting requests to the callee argument.
     Function *F = getAssociatedFunction();
-    const IRPosition &FnPos = IRPosition::function(*F);
-    auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
-    if (getState() == FnAA.getState())
-      return ChangeStatus::UNCHANGED;
-    getState() = FnAA.getState();
-    return ChangeStatus::CHANGED;
+
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    const auto It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+
+    // If F is not a runtime function, propagate the AAKernelInfo of the callee.
+    if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+      const IRPosition &FnPos = IRPosition::function(*F);
+      auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+      if (getState() == FnAA.getState())
+        return ChangeStatus::UNCHANGED;
+      getState() = FnAA.getState();
+      return ChangeStatus::CHANGED;
+    }
+
+    // F is a runtime function that allocates or frees memory, check
+    // AAHeapToStack and AAHeapToShared.
+    KernelInfoState StateBefore = getState();
+    RuntimeFunction RF = It->getSecond();
+    assert((RF == OMPRTL___kmpc_alloc_shared ||
+            RF == OMPRTL___kmpc_free_shared) &&
+           "Expected a shared-memory alloc or free runtime call");
+
+    CallBase &CB = cast<CallBase>(getAssociatedValue());
+
+    // The dependences are optional: the isAssumed* queries below already
+    // answer false when the queried attribute is in an invalid state.
+    auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
+        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+    auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
+        *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
+
+    // Never indicate a fixpoint here. The answers are assumptions that the
+    // queried attributes may still retract in a later update, so this state
+    // has to stay open for re-evaluation. Only add the call to
+    // SPMDCompatibilityTracker when neither attribute assumes it is removed.
+    switch (RF) {
+    case OMPRTL___kmpc_alloc_shared:
+      if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
+          !HeapToSharedAA.isAssumedHeapToShared(CB))
+        SPMDCompatibilityTracker.insert(&CB);
+      break;
+    case OMPRTL___kmpc_free_shared:
+      if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
+          !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
+        SPMDCompatibilityTracker.insert(&CB);
+      break;
+    default:
+      SPMDCompatibilityTracker.insert(&CB);
+    }
+
+    return StateBefore == getState() ? ChangeStatus::UNCHANGED
+                                     : ChangeStatus::CHANGED;
   }
 };
 