diff --git a/llvm/include/llvm/IR/Value.h b/llvm/include/llvm/IR/Value.h
--- a/llvm/include/llvm/IR/Value.h
+++ b/llvm/include/llvm/IR/Value.h
@@ -576,8 +576,12 @@
   ///
   /// If CanBeNull is set by this function the pointer can either be null or be
   /// dereferenceable up to the returned number of bytes.
-  uint64_t getPointerDereferenceableBytes(const DataLayout &DL,
-                                          bool &CanBeNull) const;
+  ///
+  /// If \p CtxI is given, determine the dereferenceable bytes at its position.
+  /// Otherwise, return the number of globally dereferenceable bytes.
+  uint64_t
+  getPointerDereferenceableBytes(const DataLayout &DL, bool &CanBeNull,
+                                 const Instruction *CtxI = nullptr) const;
 
   /// Returns an alignment of the pointer value.
   ///
diff --git a/llvm/lib/Analysis/CaptureTracking.cpp b/llvm/lib/Analysis/CaptureTracking.cpp
--- a/llvm/lib/Analysis/CaptureTracking.cpp
+++ b/llvm/lib/Analysis/CaptureTracking.cpp
@@ -352,7 +352,8 @@
         // cannot lead to pointer escapes, because if it is not null it
         // must be a valid (in-bounds) pointer.
         bool CanBeNull;
-        if (O->getPointerDereferenceableBytes(I->getModule()->getDataLayout(), CanBeNull))
+        if (O->getPointerDereferenceableBytes(I->getModule()->getDataLayout(),
+                                              CanBeNull, I))
           break;
       }
     }
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -67,8 +67,8 @@
                                               DL, CtxI, DT, Visited);
 
   bool CheckForNonNull = false;
-  APInt KnownDerefBytes(Size.getBitWidth(),
-                        V->getPointerDereferenceableBytes(DL, CheckForNonNull));
+  APInt KnownDerefBytes(Size.getBitWidth(), V->getPointerDereferenceableBytes(
+                                                DL, CheckForNonNull, CtxI));
   if (KnownDerefBytes.getBoolValue()) {
     if (KnownDerefBytes.uge(Size))
       if (!CheckForNonNull || isKnownNonZero(V, DL, 0, nullptr, CtxI, DT))
diff --git a/llvm/lib/IR/Value.cpp b/llvm/lib/IR/Value.cpp
--- a/llvm/lib/IR/Value.cpp
+++ b/llvm/lib/IR/Value.cpp
@@ -15,6 +15,8 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -41,6 +43,11 @@
     "non-global-value-max-name-size", cl::Hidden, cl::init(1024),
     cl::desc("Maximum size for the name of non-global values."));
 
+static cl::opt<unsigned> MaxCheckedInstructionsForNoFreeInBetween(
+    "max-checked-instructions-for-no-free-in-between", cl::Hidden, cl::init(0),
+    cl::desc("Maximum number of instructions checked to verify memory was not "
+             "deallocated."));
+
 //===----------------------------------------------------------------------===//
 //                                Value Class
 //===----------------------------------------------------------------------===//
@@ -601,36 +608,133 @@
   return stripPointerCastsAndOffsets<PSK_InBounds>(this);
 }
 
+/// Return true if there cannot be a memory deallocation (aka. free) in-between
+/// \p SrcI and \p DstI, i.e., no direct or transitive "free" call and no
+/// synchronization (another thread could free the memory) on any path from
+/// \p SrcI to \p DstI. Conservatively returns false if the search is too big.
+static bool noFreeInBetween(const Instruction *SrcI, const Instruction *DstI) {
+  const Function *F = SrcI->getFunction();
+  // If the function is no-free and no-sync there cannot be a free.
+  if (F->hasFnAttribute(Attribute::NoFree) &&
+      F->hasFnAttribute(Attribute::NoSync))
+    return true;
+
+  // If we do not want to check any instructions we give up now.
+  if (MaxCheckedInstructionsForNoFreeInBetween == 0)
+    return false;
+
+  SmallPtrSet<const Instruction *, 8> Visited;
+  SmallVector<const Instruction *, 8> Worklist;
+  Worklist.push_back(DstI);
+
+  do {
+    const Instruction *CurI = Worklist.pop_back_val();
+
+    // Never visit an instruction twice.
+    if (!Visited.insert(CurI).second)
+      continue;
+
+    // Make sure we do not waste too much time trying to prove this.
+    if (Visited.size() > MaxCheckedInstructionsForNoFreeInBetween)
+      return false;
+
+    // Only calls can deallocate, aka. free, memory or synchronize.
+    if (ImmutableCallSite ICS = ImmutableCallSite(CurI)) {
+      if (!ICS.hasFnAttr(Attribute::NoFree) || !ICS.hasFnAttr(Attribute::NoSync))
+        return false;
+    }
+
+    // Once SrcI is reached we are done traversing for this instruction.
+    if (CurI == SrcI)
+      continue;
+
+    // Walk backwards; at the start of a block, continue in all predecessors.
+    if (const Instruction *PrevI = CurI->getPrevNode()) {
+      Worklist.push_back(PrevI);
+    } else {
+      for (const BasicBlock *PredBB : predecessors(CurI->getParent()))
+        Worklist.push_back(&PredBB->back());
+    }
+  } while (!Worklist.empty());
+
+  return true;
+}
+
 uint64_t Value::getPointerDereferenceableBytes(const DataLayout &DL,
-                                               bool &CanBeNull) const {
+                                               bool &CanBeNull,
+                                               const Instruction *CtxI) const {
   assert(getType()->isPointerTy() && "must be pointer");
 
   uint64_t DerefBytes = 0;
   CanBeNull = false;
   if (const Argument *A = dyn_cast<Argument>(this)) {
-    DerefBytes = A->getDereferenceableBytes();
+    DerefBytes = A->getDereferenceableGloballyBytes();
     if (DerefBytes == 0 && (A->hasByValAttr() || A->hasStructRetAttr())) {
       Type *PT = cast<PointerType>(A->getType())->getElementType();
       if (PT->isSized())
         DerefBytes = DL.getTypeStoreSize(PT);
     }
+    if (CtxI && DerefBytes == 0) {
+      uint64_t DerefBytesAtDef = A->getDereferenceableBytes();
+      if (DerefBytesAtDef &&
+          noFreeInBetween(&A->getParent()->getEntryBlock().front(), CtxI))
+        DerefBytes = DerefBytesAtDef;
+    }
     if (DerefBytes == 0) {
-      DerefBytes = A->getDereferenceableOrNullBytes();
+      DerefBytes = A->getDereferenceableOrNullBytesGlobally();
       CanBeNull = true;
     }
+    if (CtxI && DerefBytes == 0) {
+      uint64_t DerefBytesAtDef = A->getDereferenceableOrNullBytes();
+      if (DerefBytesAtDef &&
+          noFreeInBetween(&A->getParent()->getEntryBlock().front(), CtxI)) {
+        DerefBytes = DerefBytesAtDef;
+        CanBeNull = true;
+      }
+    }
   } else if (const auto *Call = dyn_cast<CallBase>(this)) {
-    DerefBytes = Call->getDereferenceableBytes(AttributeList::ReturnIndex);
+    DerefBytes =
+        Call->getDereferenceableGloballyBytes(AttributeList::ReturnIndex);
+    if (CtxI && DerefBytes == 0) {
+      uint64_t DerefBytesAtDef =
+          Call->getDereferenceableBytes(AttributeList::ReturnIndex);
+      if (DerefBytesAtDef && noFreeInBetween(Call, CtxI))
+        DerefBytes = DerefBytesAtDef;
+    }
     if (DerefBytes == 0) {
-      DerefBytes =
-          Call->getDereferenceableOrNullBytes(AttributeList::ReturnIndex);
+      DerefBytes =
+          Call->getDereferenceableOrNullBytesGlobally(
+              AttributeList::ReturnIndex);
       CanBeNull = true;
     }
+    if (CtxI && DerefBytes == 0) {
+      uint64_t DerefBytesAtDef =
+          Call->getDereferenceableOrNullBytes(AttributeList::ReturnIndex);
+      if (DerefBytesAtDef && noFreeInBetween(Call, CtxI)) {
+        DerefBytes = DerefBytesAtDef;
+        CanBeNull = true;
+      }
+    }
   } else if (const LoadInst *LI = dyn_cast<LoadInst>(this)) {
-    if (MDNode *MD = LI->getMetadata(LLVMContext::MD_dereferenceable)) {
+    if (MDNode *MD =
+            LI->getMetadata(LLVMContext::MD_dereferenceable_globally)) {
       ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0));
       DerefBytes = CI->getLimitedValue();
     }
+    if (CtxI && DerefBytes == 0) {
+      if (MDNode *MD = LI->getMetadata(LLVMContext::MD_dereferenceable)) {
+        ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0));
+        DerefBytes = CI->getLimitedValue();
+      }
+    }
     if (DerefBytes == 0) {
+      if (MDNode *MD = LI->getMetadata(
+              LLVMContext::MD_dereferenceable_or_null_globally)) {
+        ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0));
+        DerefBytes = CI->getLimitedValue();
+      }
+      CanBeNull = true;
+    }
+    if (CtxI && DerefBytes == 0) {
       if (MDNode *MD =
               LI->getMetadata(LLVMContext::MD_dereferenceable_or_null)) {
         ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0));
diff --git a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
--- a/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
+++ b/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp
@@ -493,7 +493,8 @@
     CallSite CS(U);
     assert(CS && "Should only have direct calls!");
 
-    if (!isDereferenceablePointer(CS.getArgument(ArgNo), DL))
+    if (!isDereferenceablePointer(CS.getArgument(ArgNo), DL,
+                                  CS.getInstruction()))
       return false;
   }
   return true;