Index: llvm/include/llvm/Transforms/IPO/Attributor.h
===================================================================
--- llvm/include/llvm/Transforms/IPO/Attributor.h
+++ llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -455,6 +455,27 @@
     return *this;
   }
 
+  /// Take minimum of assumed and \p Value.
+  IntegerState &takeAssumedMinimum(base_t Value) {
+    // Make sure we never lose the "known value".
+    Assumed = std::max(std::min(Assumed, Value), Known);
+    return *this;
+  }
+
+  /// Take maximum of known and \p Value; assumed is raised along with it.
+  IntegerState &takeKnownMaximum(base_t Value) {
+    // Keep the invariant that assumed is never below known.
+    Assumed = std::max(Value, Assumed);
+    Known = std::max(Value, Known);
+    return *this;
+  }
+
+  /// Equality for IntegerState.
+  bool operator==(const IntegerState &R) const {
+    return this->getAssumed() == R.getAssumed() &&
+           this->getKnown() == R.getKnown();
+  }
+
 private:
   /// The known state encoding in an integer of type base_t.
   base_t Known = getWorstState();
@@ -740,6 +761,45 @@
   /// The identifier used by the Attributor for this class of attributes.
   static constexpr Attribute::AttrKind ID = Attribute::NonNull;
 };
+
+/// An abstract interface for all dereferenceable attributes.
+struct AADereferenceable : public AbstractAttribute {
+
+  /// See AbstractAttribute::AbstractAttribute(...).
+  AADereferenceable(Value &V, InformationCache &InfoCache)
+      : AbstractAttribute(V, InfoCache) {}
+
+  /// See AbstractAttribute::AbstractAttribute(...).
+  AADereferenceable(Value *AssociatedVal, Value &AnchoredValue,
+                    InformationCache &InfoCache)
+      : AbstractAttribute(AssociatedVal, AnchoredValue, InfoCache) {}
+
+  /// Return true if we assume that the underlying value is nonnull.
+  virtual bool isAssumedNonNull() const = 0;
+
+  /// Return true if we know that the underlying value is nonnull.
+  virtual bool isKnownNonNull() const = 0;
+
+  /// Return true if we assume that the underlying value is
+  /// dereferenceable(_or_null) globally.
+  virtual bool isAssumedGlobal() const = 0;
+
+  /// Return true if we know that the underlying value is
+  /// dereferenceable(_or_null) globally.
+  virtual bool isKnownGlobal() const = 0;
+
+  /// Return assumed dereferenceable bytes.
+  virtual uint32_t getAssumedDereferenceableBytes() const = 0;
+
+  /// Return known dereferenceable bytes.
+  virtual uint32_t getKnownDereferenceableBytes() const = 0;
+
+  /// See AbstractAttribute::getAttrKind().
+  Attribute::AttrKind getAttrKind() const override { return ID; }
+
+  /// The identifier used by the Attributor for this class of attributes.
+  static constexpr Attribute::AttrKind ID = Attribute::Dereferenceable;
+};
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_FUNCTIONATTRS_H
Index: llvm/lib/Transforms/IPO/Attributor.cpp
===================================================================
--- llvm/lib/Transforms/IPO/Attributor.cpp
+++ llvm/lib/Transforms/IPO/Attributor.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
@@ -264,6 +265,14 @@
     Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr);
     return true;
   }
+  if (Attr.isIntAttribute()) {
+    Attribute::AttrKind Kind = Attr.getKindAsEnum();
+    if (Attrs.hasAttribute(AttrIdx, Kind))
+      if (isEqualOrWorse(Attr, Attrs.getAttribute(AttrIdx, Kind)))
+        return false;
+    Attrs = Attrs.addAttribute(Ctx, AttrIdx, Attr);
+    return true;
+  }
-  llvm_unreachable("Expected enum or string attribute!");
+  llvm_unreachable("Expected enum, int, or string attribute!");
 }
 
@@ -1065,7 +1074,10 @@
 
     // Already nonnull.
     if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
-                                       Attribute::NonNull))
+                                       Attribute::NonNull) ||
+        (F.getDereferenceableBytes(AttributeList::ReturnIndex) &&
+         !NullPointerIsDefined(&F,
+                               F.getReturnType()->getPointerAddressSpace())))
       indicateOptimisticFixpoint();
   }
 
@@ -1124,6 +1136,10 @@
     CallSite CS(&getAnchoredValue());
     if (isKnownNonZero(getAssociatedValue(),
                        getAnchorScope().getParent()->getDataLayout()) ||
+        (CS.paramHasAttr(ArgNo, Attribute::Dereferenceable) &&
+         !NullPointerIsDefined(
+             &getAnchorScope(),
+             getAssociatedValue()->getType()->getPointerAddressSpace())) ||
         CS.paramHasAttr(ArgNo, getAttrKind()))
       indicateOptimisticFixpoint();
   }
@@ -1194,6 +1210,350 @@
   return ChangeStatus::UNCHANGED;
 }
 
+/// ------------------------ Dereferenceable Attribute -------------------------
+
+/// State of AADereferenceable: dereferenceable bytes plus nonnull/global bits.
+struct DerefState : AbstractState {
+  IntegerState DerefBytesState;
+  IntegerState NonNullGlobalState;
+  enum Deref {
+    DEREF_NONNULL = 1 << 0,
+    DEREF_GLOBAL = 1 << 1,
+  };
+
+  /// See AbstractState::isValidState()
+  bool isValidState() const override { return DerefBytesState.isValidState(); }
+
+  /// See AbstractState::isAtFixpoint()
+  bool isAtFixpoint() const override {
+    return DerefBytesState.isAtFixpoint() && NonNullGlobalState.isAtFixpoint();
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  void indicateOptimisticFixpoint() override {
+    DerefBytesState.indicateOptimisticFixpoint();
+    NonNullGlobalState.indicateOptimisticFixpoint();
+  }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  void indicatePessimisticFixpoint() override {
+    DerefBytesState.indicatePessimisticFixpoint();
+    NonNullGlobalState.indicatePessimisticFixpoint();
+  }
+
+  /// Update known dereferenceable bytes.
+  void takeKnownDerefBytesMaximum(uint64_t Bytes) {
+    DerefBytesState.takeKnownMaximum(Bytes);
+  }
+
+  /// Update assumed dereferenceable bytes.
+  void takeAssumedDerefBytesMinimum(uint64_t Bytes) {
+    DerefBytesState.takeAssumedMinimum(Bytes);
+  }
+
+  /// Equality for DerefState.
+  bool operator==(const DerefState &R) const {
+    return this->DerefBytesState == R.DerefBytesState &&
+           this->NonNullGlobalState == R.NonNullGlobalState;
+  }
+};
+
+/// Common base of the returned, argument, and call site argument variants.
+struct AADereferenceableImpl : AADereferenceable, DerefState {
+
+  AADereferenceableImpl(Value &V, InformationCache &InfoCache)
+      : AADereferenceable(V, InfoCache) {}
+
+  AADereferenceableImpl(Value *AssociatedVal, Value &AnchoredValue,
+                        InformationCache &InfoCache)
+      : AADereferenceable(AssociatedVal, AnchoredValue, InfoCache) {}
+
+  /// See AbstractAttribute::getState()
+  /// {
+  AbstractState &getState() override { return *this; }
+  const AbstractState &getState() const override { return *this; }
+  /// }
+
+  /// See AADereferenceable::getAssumedDereferenceableBytes().
+  uint32_t getAssumedDereferenceableBytes() const override {
+    return DerefBytesState.getAssumed();
+  }
+
+  /// See AADereferenceable::getKnownDereferenceableBytes().
+  uint32_t getKnownDereferenceableBytes() const override {
+    return DerefBytesState.getKnown();
+  }
+
+  /// Mirror the nonnull state of \p NonNullAA (if any) into DEREF_NONNULL.
+  void syncNonNull(const AANonNull *NonNullAA) {
+    if (!NonNullAA) {
+      NonNullGlobalState.removeAssumedBits(DEREF_NONNULL);
+      return;
+    }
+
+    if (NonNullAA->isKnownNonNull())
+      NonNullGlobalState.addKnownBits(DEREF_NONNULL);
+
+    if (!NonNullAA->isAssumedNonNull())
+      NonNullGlobalState.removeAssumedBits(DEREF_NONNULL);
+  }
+
+  /// See AADereferenceable::isAssumedGlobal().
+  bool isAssumedGlobal() const override {
+    return NonNullGlobalState.isAssumed(DEREF_GLOBAL);
+  }
+
+  /// See AADereferenceable::isKnownGlobal().
+  bool isKnownGlobal() const override {
+    return NonNullGlobalState.isKnown(DEREF_GLOBAL);
+  }
+
+  /// See AADereferenceable::isAssumedNonNull().
+  bool isAssumedNonNull() const override {
+    return NonNullGlobalState.isAssumed(DEREF_NONNULL);
+  }
+
+  /// See AADereferenceable::isKnownNonNull().
+  bool isKnownNonNull() const override {
+    return NonNullGlobalState.isKnown(DEREF_NONNULL);
+  }
+
+  /// See AbstractAttribute::getDeducedAttributes(...).
+  void getDeducedAttributes(SmallVectorImpl<Attribute> &Attrs) const override {
+    LLVMContext &Ctx = AnchoredVal.getContext();
+
+    // TODO: Add *_globally support
+    if (isAssumedNonNull())
+      Attrs.emplace_back(Attribute::getWithDereferenceableBytes(
+          Ctx, getAssumedDereferenceableBytes()));
+    else
+      Attrs.emplace_back(Attribute::getWithDereferenceableOrNullBytes(
+          Ctx, getAssumedDereferenceableBytes()));
+  }
+
+  /// Return the number of bytes assumed to be dereferenceable for \p V.
+  uint64_t computeAssumedDerefenceableBytes(Attributor &A, Value &V);
+
+  /// See AbstractAttribute::getAsStr().
+  const std::string getAsStr() const override {
+    if (!getAssumedDereferenceableBytes())
+      return "unknown-dereferenceable";
+    return std::string("dereferenceable") +
+           (isAssumedNonNull() ? "" : "_or_null") +
+           (isAssumedGlobal() ? "_globally" : "") + "<" +
+           std::to_string(getKnownDereferenceableBytes()) + "-" +
+           std::to_string(getAssumedDereferenceableBytes()) + ">";
+  }
+};
+
+struct AADereferenceableReturned : AADereferenceableImpl {
+  AADereferenceableReturned(Function &F, InformationCache &InfoCache)
+      : AADereferenceableImpl(F, InfoCache) {}
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_RETURNED; }
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    Function &F = getAnchorScope();
+
+    if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                       Attribute::Dereferenceable))
+      takeKnownDerefBytesMaximum(
+          F.getDereferenceableBytes(AttributeList::ReturnIndex));
+
+    if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                       Attribute::DereferenceableOrNull))
+      takeKnownDerefBytesMaximum(
+          F.getDereferenceableOrNullBytes(AttributeList::ReturnIndex));
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override;
+};
+
+// Dereferenceable bytes left at \p Offset past the base; 0 if it may be null.
+static uint64_t calcDifferenceIfBaseIsNonNull(int64_t DerefBytes,
+                                              int64_t Offset, bool IsNonNull) {
+  if (!IsNonNull)
+    return 0;
+  return std::max<int64_t>(0, DerefBytes - Offset);
+}
+
+uint64_t AADereferenceableImpl::computeAssumedDerefenceableBytes(Attributor &A,
+                                                                 Value &V) {
+  // First, we try to get information about V from Attributor.
+  auto *DerefAA = A.getAAFor<AADereferenceable>(*this, V);
+  if (DerefAA)
+    return DerefAA->getAssumedDereferenceableBytes();
+
+  // Otherwise, we try to compute assumed bytes from base pointer.
+  const DataLayout &DL = getAnchorScope().getParent()->getDataLayout();
+  unsigned IdxWidth =
+      DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace());
+  APInt Offset(IdxWidth, 0);
+  Value *Base = V.stripAndAccumulateInBoundsConstantOffsets(DL, Offset);
+
+  auto *BaseDerefAA = A.getAAFor<AADereferenceable>(*this, *Base);
+
+  if (BaseDerefAA)
+    return calcDifferenceIfBaseIsNonNull(
+        BaseDerefAA->getAssumedDereferenceableBytes(), Offset.getSExtValue(),
+        BaseDerefAA->isAssumedNonNull());
+
+  // Then, use IR information.
+
+  bool BaseKnownCanBeNull = false;
+  if (uint64_t BaseKnownDerefBytes =
+          Base->getPointerDereferenceableBytes(DL, BaseKnownCanBeNull))
+    return calcDifferenceIfBaseIsNonNull(
+        BaseKnownDerefBytes, Offset.getSExtValue(), !BaseKnownCanBeNull);
+
+  Type *BaseElemTy = Base->getType()->getPointerElementType();
+  if (isDereferenceablePointer(Base, BaseElemTy, DL))
+    return calcDifferenceIfBaseIsNonNull(DL.getTypeStoreSize(BaseElemTy),
+                                         Offset.getSExtValue(), true);
+
+  return 0;
+}
+ChangeStatus AADereferenceableReturned::updateImpl(Attributor &A) {
+  Function &F = getAnchorScope();
+  auto BeforeState = static_cast<DerefState>(*this);
+
+  syncNonNull(A.getAAFor<AANonNull>(*this, F));
+
+  auto *AARetVal = A.getAAFor<AAReturnedValues>(*this, F);
+  if (!AARetVal) {
+    indicatePessimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+
+  std::function<bool(Value &)> Pred = [&](Value &RV) -> bool {
+    takeAssumedDerefBytesMinimum(computeAssumedDerefenceableBytes(A, RV));
+    return isValidState();
+  };
+
+  if (AARetVal->checkForallReturnedValues(Pred))
+    return BeforeState == static_cast<DerefState>(*this)
+               ? ChangeStatus::UNCHANGED
+               : ChangeStatus::CHANGED;
+
+  indicatePessimisticFixpoint();
+  return ChangeStatus::CHANGED;
+}
+
+struct AADereferenceableArgument : AADereferenceableImpl {
+  AADereferenceableArgument(Argument &A, InformationCache &InfoCache)
+      : AADereferenceableImpl(A, InfoCache) {}
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override { return MP_ARGUMENT; }
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    Argument *Arg = cast<Argument>(getAssociatedValue());
+    if (Arg->hasAttribute(Attribute::Dereferenceable))
+      takeKnownDerefBytesMaximum(Arg->getDereferenceableBytes());
+    if (Arg->hasAttribute(Attribute::DereferenceableOrNull))
+      takeKnownDerefBytesMaximum(Arg->getDereferenceableOrNullBytes());
+  }
+
+  /// See AbstractAttribute::updateImpl(...).
+  ChangeStatus updateImpl(Attributor &A) override;
+};
+
+ChangeStatus AADereferenceableArgument::updateImpl(Attributor &A) {
+  Function &F = getAnchorScope();
+  Argument &Arg = cast<Argument>(getAnchoredValue());
+
+  auto BeforeState = static_cast<DerefState>(*this);
+
+  unsigned ArgNo = Arg.getArgNo();
+
+  syncNonNull(A.getAAFor<AANonNull>(*this, F, ArgNo));
+
+  // Take the minimum over the values passed at every call site.
+  std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) -> bool {
+    assert(CS && "Sanity check: Call site was not initialized properly!");
+
+    auto *DereferenceableAA =
+        A.getAAFor<AADereferenceable>(*this, *CS.getInstruction(), ArgNo);
+
+    // Check that DereferenceableAA is AADereferenceableCallSiteArgument.
+    if (DereferenceableAA) {
+      ImmutableCallSite ICS(&DereferenceableAA->getAnchoredValue());
+      if (ICS && CS.getInstruction() == ICS.getInstruction()) {
+        takeAssumedDerefBytesMinimum(
+            DereferenceableAA->getAssumedDereferenceableBytes());
+        return isValidState();
+      }
+    }
+
+    takeAssumedDerefBytesMinimum(
+        computeAssumedDerefenceableBytes(A, *CS.getArgOperand(ArgNo)));
+
+    return isValidState();
+  };
+
+  if (!A.checkForAllCallSites(F, CallSiteCheck, true)) {
+    indicatePessimisticFixpoint();
+    return ChangeStatus::CHANGED;
+  }
+  return BeforeState == static_cast<DerefState>(*this) ? ChangeStatus::UNCHANGED
+                                                      : ChangeStatus::CHANGED;
+}
+
+/// Dereferenceable attribute for a call site argument.
+struct AADereferenceableCallSiteArgument : AADereferenceableImpl {
+
+  /// See AADereferenceableImpl::AADereferenceableImpl(...).
+  AADereferenceableCallSiteArgument(CallSite CS, unsigned ArgNo,
+                                    InformationCache &InfoCache)
+      : AADereferenceableImpl(CS.getArgOperand(ArgNo), *CS.getInstruction(),
+                              InfoCache),
+        ArgNo(ArgNo) {}
+
+  /// See AbstractAttribute::initialize(...).
+  void initialize(Attributor &A) override {
+    CallSite CS(&getAnchoredValue());
+    unsigned AttrIdx = ArgNo + AttributeList::FirstArgIndex;
+    if (CS.paramHasAttr(ArgNo, Attribute::Dereferenceable))
+      takeKnownDerefBytesMaximum(CS.getDereferenceableBytes(AttrIdx));
+    if (CS.paramHasAttr(ArgNo, Attribute::DereferenceableOrNull))
+      takeKnownDerefBytesMaximum(CS.getDereferenceableOrNullBytes(AttrIdx));
+  }
+
+  /// See AbstractAttribute::updateImpl(Attributor &A).
+  ChangeStatus updateImpl(Attributor &A) override;
+
+  /// See AbstractAttribute::getManifestPosition().
+  ManifestPosition getManifestPosition() const override {
+    return MP_CALL_SITE_ARGUMENT;
+  }
+
+  /// Return the argument index of the associated value.
+  int getArgNo() const { return ArgNo; }
+
+private:
+  unsigned ArgNo;
+};
+
+ChangeStatus AADereferenceableCallSiteArgument::updateImpl(Attributor &A) {
+  // NOTE: Never look at the argument of the callee in this method.
+  //       If we do this, "dereferenceable" is always deduced because of the
+  //       assumption.
+
+  Value &V = *getAssociatedValue();
+
+  auto BeforeState = static_cast<DerefState>(*this);
+
+  syncNonNull(A.getAAFor<AANonNull>(*this, getAnchoredValue(), ArgNo));
+  takeAssumedDerefBytesMinimum(computeAssumedDerefenceableBytes(A, V));
+
+  return BeforeState == static_cast<DerefState>(*this) ? ChangeStatus::UNCHANGED
+                                                      : ChangeStatus::CHANGED;
+}
+
 /// ----------------------------------------------------------------------------
 /// Attributor
 /// ----------------------------------------------------------------------------
@@ -1395,12 +1745,22 @@
     if (ReturnType->isPointerTy() &&
         (!Whitelist || Whitelist->count(AANonNullReturned::ID)))
       registerAA(*new AANonNullReturned(F, InfoCache));
+
+    // Every function with pointer return type might be dereferenceable.
+    if (ReturnType->isPointerTy() &&
+        (!Whitelist || Whitelist->count(AADereferenceableReturned::ID)))
+      registerAA(*new AADereferenceableReturned(F, InfoCache));
   }
 
-  // Every argument with pointer type might be marked nonnull.
   for (Argument &Arg : F.args()) {
-    if (Arg.getType()->isPointerTy())
+    if (Arg.getType()->isPointerTy()) {
+      // Every argument with pointer type might be marked nonnull.
       registerAA(*new AANonNullArgument(Arg, InfoCache));
+
+      // Every argument with pointer type might be marked dereferenceable.
+      if (!Whitelist || Whitelist->count(AADereferenceableArgument::ID))
+        registerAA(*new AADereferenceableArgument(Arg, InfoCache));
+    }
   }
 
   // Walk all instructions to find more attribute opportunities and also
@@ -1445,6 +1805,12 @@
 
         // Call site argument attribute "non-null".
         registerAA(*new AANonNullCallSiteArgument(CS, i, InfoCache), i);
+
+        // Call site argument attribute "dereferenceable".
+        if (!Whitelist ||
+            Whitelist->count(AADereferenceableCallSiteArgument::ID))
+          registerAA(
+              *new AADereferenceableCallSiteArgument(CS, i, InfoCache), i);
       }
     }
   }
Index: llvm/test/Transforms/FunctionAttrs/arg_nocapture.ll
===================================================================
--- llvm/test/Transforms/FunctionAttrs/arg_nocapture.ll
+++ llvm/test/Transforms/FunctionAttrs/arg_nocapture.ll
@@ -88,7 +88,7 @@
 ; Other arguments are possible here due to the no-return behavior.
 ;
 ; FIXME: no-return missing
-; CHECK: define noalias nonnull i32* @srec16(i32* nocapture readnone %a)
+; CHECK: define noalias nonnull dereferenceable(4294967295) i32* @srec16(i32* nocapture readnone %a)
 define i32* @srec16(i32* %a) #0 {
 entry:
   %call = call i32* @srec16(i32* %a)
Index: llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
===================================================================
--- llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
+++ llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
@@ -0,0 +1,45 @@
+; RUN: opt -attributor --attributor-disable=false -S < %s | FileCheck %s --check-prefixes=ATTRIBUTOR
+
+
+; TEST 1
+; take minimum of return values
+; (the select returns a deref(4) or a deref(8) pointer, so deref(4) is deduced)
+define i32* @test1(i32* dereferenceable(4), double* dereferenceable(8), i1 zeroext) local_unnamed_addr {
+; ATTRIBUTOR: define nonnull dereferenceable(4) i32* @test1(i32* nonnull dereferenceable(4), double* nonnull dereferenceable(8), i1 zeroext)
+  %4 = bitcast double* %1 to i32*
+  %5 = select i1 %2, i32* %0, i32* %4
+  ret i32* %5
+}
+
+; TEST 2
+define i32* @test2(i32* dereferenceable_or_null(4), double* dereferenceable(8), i1 zeroext) local_unnamed_addr {
+; ATTRIBUTOR: define dereferenceable_or_null(4) i32* @test2(i32* dereferenceable_or_null(4), double* nonnull dereferenceable(8), i1 zeroext)
+  %4 = bitcast double* %1 to i32*
+  %5 = select i1 %2, i32* %0, i32* %4
+  ret i32* %5
+}
+
+; TEST 3
+; GEP inbounds
+define i32* @test3_1(i32* dereferenceable(8)) local_unnamed_addr {
+; ATTRIBUTOR: define nonnull dereferenceable(4) i32* @test3_1(i32* nonnull dereferenceable(8))
+  %ret = getelementptr inbounds i32, i32* %0, i64 1
+  ret i32* %ret
+}
+
+; If %0 is null, we can't assume dereferenceable_or_null for the return value.
+define i32* @test3_2(i32* dereferenceable_or_null(32)) local_unnamed_addr {
+; ATTRIBUTOR: define nonnull i32* @test3_2(i32* dereferenceable_or_null(32))
+  %ret = getelementptr inbounds i32, i32* %0, i64 4
+  ret i32* %ret
+}
+
+define i32* @test3_3(i32* dereferenceable(8), i32* dereferenceable(16), i1) local_unnamed_addr {
+; ATTRIBUTOR: define nonnull dereferenceable(4) i32* @test3_3(i32* nonnull dereferenceable(8), i32* nonnull dereferenceable(16), i1) local_unnamed_addr
+  %ret1 = getelementptr inbounds i32, i32* %0, i64 1
+  %ret2 = getelementptr inbounds i32, i32* %1, i64 2
+  %ret = select i1 %2, i32* %ret1, i32* %ret2
+  ret i32* %ret
+}
+
+
Index: llvm/test/Transforms/FunctionAttrs/nonnull.ll
===================================================================
--- llvm/test/Transforms/FunctionAttrs/nonnull.ll
+++ llvm/test/Transforms/FunctionAttrs/nonnull.ll
@@ -39,14 +39,14 @@
 ; just never return period.)
 define i8* @test4_helper() {
 ; FNATTR: define noalias nonnull i8* @test4_helper
-; ATTRIBUTOR: define nonnull i8* @test4_helper
+; ATTRIBUTOR: define nonnull dereferenceable(4294967295) i8* @test4_helper
   %ret = call i8* @test4()
   ret i8* %ret
 }
 
 define i8* @test4() {
 ; FNATTR: define noalias nonnull i8* @test4
-; ATTRIBUTOR: define nonnull i8* @test4
+; ATTRIBUTOR: define nonnull dereferenceable(4294967295) i8* @test4
   %ret = call i8* @test4_helper()
   ret i8* %ret
 }
@@ -218,6 +218,15 @@
   %tmp = call i32* @f1(i32* %arg)
   ret i32* null
 }
+
+; TEST 15
+define void @f15(i8* %arg) {
+; ATTRIBUTOR: tail call void @use1(i8* nonnull dereferenceable(4) %arg)
+  ; A dereferenceable(4) call site argument is expected to become nonnull too.
+  tail call void @use1(i8* dereferenceable(4) %arg)
+  ret void
+}
+
 
 ; Test propagation of nonnull callsite args back to caller.
 declare void @use1(i8* %x)