Index: llvm/include/llvm/Analysis/MustExecute.h
===================================================================
--- llvm/include/llvm/Analysis/MustExecute.h
+++ llvm/include/llvm/Analysis/MustExecute.h
@@ -392,13 +392,99 @@
   const Instruction *
   getMustBeExecutedNextInstruction(MustBeExecutedIterator &It,
                                    const Instruction *PP);
 
+  /// Base class for the abstract states combined by
+  /// checkPredicateAfterInstruction.
+  ///
+  /// A state type \c AState used with checkPredicateAfterInstruction derives
+  /// from this class and provides:
+  ///  - `AState &operator&=(const AState &)`, the conjunction used to merge
+  ///    the states of all successors of a conditional terminator;
+  ///  - `AState &operator|=(const AState &)`, the disjunction used to
+  ///    accumulate the states along a must-be-executed context;
+  ///  - `static AState Top()`, the identity of the conjunction;
+  ///  - `static AState Bottom()`, the identity of the disjunction.
+  /// These members are looked up statically on \c AState, so they are not
+  /// declared here.
+  struct KnownAbstractState {
+    virtual ~KnownAbstractState() = default;
+
+    /// Return true if there is no need to explore more paths. The default
+    /// never stops the exploration early.
+    virtual bool stopExplore() const { return false; }
+  };
+
+  /// Explore the paths which must be executed after \p PP and collect the
+  /// abstract state known to hold on all of them.
+  ///
+  /// The states of the instructions in the must-be-executed context of \p PP
+  /// are combined by disjunction. When a conditional terminator is reached,
+  /// every successor is explored recursively and the successor states are
+  /// combined by conjunction before being joined into the current state:
+  ///
+  ///   (I_1 | I_2 | ... | I_n) | (C_1 & C_2 & ... & C_m)
+  ///
+  /// \param Pred The function that maps an instruction to the state it
+  ///             establishes.
+  /// \param PP The program point for which the next instruction that is
+  ///           guaranteed to execute is determined.
+  /// \param MemoizeResult Cache of the states computed so far, keyed by the
+  ///                      first instruction of a context. While a context
+  ///                      is still being explored its entry holds the
+  ///                      partial state, which cuts cycles in the CFG.
+  template <typename AState>
+  AState checkPredicateAfterInstruction(
+      const std::function<AState(const Instruction *)> &Pred,
+      const Instruction *PP,
+      DenseMap<const Instruction *, AState> &MemoizeResult) {
+    // For each program point, we regard *begin(PP) as ID for the context.
+    const Instruction *PPBegin = *begin(PP);
+
+    auto CachedIt = MemoizeResult.find(PPBegin);
+    if (CachedIt != MemoizeResult.end())
+      return CachedIt->second;
+
+    // The cache owns copies of the states: a pointer to the local state
+    // below would dangle as soon as this call returns.
+    auto Memoize = [&](const AState &S) {
+      auto InsertResult = MemoizeResult.try_emplace(PPBegin, S);
+      if (!InsertResult.second)
+        InsertResult.first->second = S;
+    };
+
+    AState AS = AState::Bottom();
+    for (const Instruction *I : range(PP)) {
+      AS |= Pred(I);
+      if (AS.stopExplore())
+        break;
+
+      if (I->isTerminator() && I->getNumSuccessors() > 1) {
+        // Publish the partial state first so that reaching this context
+        // again through a back edge terminates the recursion.
+        Memoize(AS);
+        AState ChildrenAS = AState::Top();
+        for (const BasicBlock *BB : successors(I))
+          ChildrenAS &= checkPredicateAfterInstruction(Pred, &BB->front(),
+                                                       MemoizeResult);
+        AS |= ChildrenAS;
+        if (AS.stopExplore())
+          break;
+      }
+    }
+
+    Memoize(AS);
+    return AS;
+  }
+
+  /// Convenience overload of checkPredicateAfterInstruction that starts
+  /// with an empty memoization cache.
+  template <typename AState>
+  AState checkPredicateAfterInstruction(
+      const std::function<AState(const Instruction *)> &Pred,
+      const Instruction *PP) {
+    DenseMap<const Instruction *, AState> MemoizeResult;
+    return checkPredicateAfterInstruction(Pred, PP, MemoizeResult);
+  }
-  /// Return true if there is a instruction that \p Pred holds in the context
-  /// around \p PP.
-  bool
-  checkPredicateAfterInstruction(std::function<bool(const Instruction *)> Pred,
-                                 const Instruction *PP,
-                                 bool ContinueExplorationWhenPredHolds = false);
 
   /// Parameter that limit the performed exploration. See the constructor for
   /// their meaning.
Index: llvm/lib/Analysis/MustExecute.cpp
===================================================================
--- llvm/lib/Analysis/MustExecute.cpp
+++ llvm/lib/Analysis/MustExecute.cpp
@@ -452,20 +452,6 @@
   return nullptr;
 }
 
-bool MustBeExecutedContextExplorer::checkPredicateAfterInstruction(
-    std::function<bool(const Instruction *)> Pred, const Instruction *PP,
-    bool ContinueExplorationWhenPredHolds) {
-  bool PredHolds = false;
-  for (const Instruction *I : range(PP)) {
-    LLVM_DEBUG(dbgs() << "[MustBeExecuted][checkPred] PP: " << *PP
-                      << " I:" << *I << "\n");
-
-    PredHolds |= Pred(I);
-    if (PredHolds && !ContinueExplorationWhenPredHolds)
-      return true;
-  }
-  return PredHolds;
-}
 
 MustBeExecutedIterator::MustBeExecutedIterator(
     MustBeExecutedContextExplorer &Explorer, const Instruction *I)
Index: llvm/lib/Transforms/IPO/Attributor.cpp
===================================================================
--- llvm/lib/Transforms/IPO/Attributor.cpp
+++ llvm/lib/Transforms/IPO/Attributor.cpp
@@ -1249,6 +1249,23 @@
 private:
   unsigned ArgNo;
 };
+
+struct ASBool : MustBeExecutedContextExplorer::KnownAbstractState {
+  bool State; // True if the queried property is known to hold.
+  explicit ASBool(bool State) : State(State) {}
+  bool stopExplore() const override { return State; } // Stop once known.
+  ASBool &operator&=(const ASBool &R) {
+    State &= R.State;
+    return *this;
+  }
+  ASBool &operator|=(const ASBool &R) {
+    State |= R.State;
+    return *this;
+  }
+  static ASBool Bottom() { return ASBool(false); } // Identity of |=.
+  static ASBool Top() { return ASBool(true); }     // Identity of &=.
+};
+
 ChangeStatus AANonNullArgument::updateImpl(Attributor &A) {
   Function &F = getAnchorScope();
   Argument &Arg = cast<Argument>(getAnchoredValue());
@@ -1257,14 +1274,15 @@
 
   MustBeExecutedContextExplorer &Explorer = InfoCache.getContextExplorer();
 
-  auto Pred = [&](const Instruction *I) -> bool {
+  std::function<ASBool(const Instruction *)> Pred =
+      [&](const Instruction *I) -> ASBool {
     // See callsite abstract attribute.
     if (ImmutableCallSite ICS = ImmutableCallSite(I))
       for (unsigned int i = 0; i < ICS.getNumArgOperands(); i++)
         if (ICS.getArgOperand(i) == &getAnchoredValue())
           if (auto *NonNullAA = A.getAAFor<AANonNull>(*this, *I, i))
             // FIXME: Use assumption.
-            return NonNullAA->isKnownNonNull();
+            return ASBool(NonNullAA->isKnownNonNull());
 
     // See memory instruction with constant offset.
     // Currently only inbounds GEPs are tracked.
@@ -1272,13 +1290,15 @@
     int64_t Offset = 0;
     if (const Value *Base = getBasePointerOfPointerOperand(
             I, Offset, getAnchorScope().getParent()->getDataLayout()))
-      return Base == &getAnchoredValue();
+      return ASBool(Base == &getAnchoredValue());
 
-    return false;
+    return ASBool::Bottom();
   };
 
-  if (Explorer.checkPredicateAfterInstruction(
-          Pred, &getAnchorScope().getEntryBlock().front())) {
+  if (Explorer
+          .checkPredicateAfterInstruction<ASBool>(
+              Pred, &getAnchorScope().getEntryBlock().front())
+          .State) {
     indicateOptimisticFixpoint();
     return ChangeStatus::CHANGED;
   }
@@ -1956,6 +1976,34 @@
   ChangeStatus updateImpl(Attributor &A) override;
 };
 
+/// Abstract state for the dereferenceability deduction: whether the pointer
+/// is known to be non-null and how many bytes are known to be
+/// dereferenceable. Conjunction (&=) keeps what holds on every successor
+/// path; disjunction (|=) accumulates what is known along one path.
+struct ASDeref : MustBeExecutedContextExplorer::KnownAbstractState {
+  bool NonNull;
+  uint32_t DerefBytes;
+  ASDeref(bool NonNull, uint32_t DerefBytes)
+      : NonNull(NonNull), DerefBytes(DerefBytes) {}
+
+  /// More instructions can only add knowledge, so never stop early.
+  bool stopExplore() const override { return false; }
+
+  ASDeref &operator&=(const ASDeref &R) {
+    NonNull &= R.NonNull;
+    DerefBytes = std::min(DerefBytes, R.DerefBytes);
+    return *this;
+  }
+  ASDeref &operator|=(const ASDeref &R) {
+    NonNull |= R.NonNull;
+    DerefBytes = std::max(DerefBytes, R.DerefBytes);
+    return *this;
+  }
+  static ASDeref Bottom() { return ASDeref(false, 0); }
+  /// Identity of &=: DerefBytes must be the largest uint32_t for std::min.
+  static ASDeref Top() { return ASDeref(true, ~uint32_t(0)); }
+};
+
 ChangeStatus AADereferenceableArgument::updateImpl(Attributor &A) {
   Function &F = getAnchorScope();
   const DataLayout &DL = F.getParent()->getDataLayout();
@@ -1973,17 +2021,17 @@
 
   MustBeExecutedContextExplorer &Explorer = InfoCache.getContextExplorer();
 
-  auto Pred = [&](const Instruction *I) {
+  std::function<ASDeref(const Instruction *)> Pred =
+      [&](const Instruction *I) -> ASDeref {
     // See callsite abstract attribute.
+    ASDeref Deref = ASDeref::Bottom();
     if (ImmutableCallSite ICS = ImmutableCallSite(I))
       for (unsigned int i = 0; i < ICS.getNumArgOperands(); i++)
         if (ICS.getArgOperand(i) == &getAnchoredValue())
-          if (auto *DerefAA = A.getAAFor<AADereferenceable>(*this, *I, i)) {
-            // FIXME: Use assumption.
-            auto DerefBytes = DerefAA->getKnownDereferenceableBytes();
-            takeKnownDerefBytesMaximum(DerefBytes);
-            return true;
-          }
+          if (auto *DerefAA = A.getAAFor<AADereferenceable>(*this, *I, i))
+            // FIXME: Use assumption.
+            Deref |= ASDeref(DerefAA->isKnownNonNull(),
+                             DerefAA->getKnownDereferenceableBytes());
 
     // See memory instruction with constant offset.
     // Currently only inbounds GEPs are tracked.
@@ -1994,16 +2042,17 @@
       uint32_t DerefBytes =
           Offset +
           DL.getTypeStoreSize(Base->getType()->getPointerElementType());
-      addKnownNonnull();
-      takeKnownDerefBytesMaximum(DerefBytes);
-      return true;
+      Deref |= ASDeref(true, DerefBytes);
     }
-    return false;
+    return Deref;
   };
 
-  Explorer.checkPredicateAfterInstruction(
-      Pred, &getAnchorScope().getEntryBlock().front(),
-      /* ContinueExplorationWhenPredHolds */ true);
+  ASDeref Res = Explorer.checkPredicateAfterInstruction<ASDeref>(
+      Pred, &getAnchorScope().getEntryBlock().front());
+
+  if (Res.NonNull)
+    addKnownNonnull();
+  takeKnownDerefBytesMaximum(Res.DerefBytes);
 
   // Callback function
   std::function<bool(CallSite)> CallSiteCheck = [&](CallSite CS) -> bool {
Index: llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
===================================================================
--- llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
+++ llvm/test/Transforms/FunctionAttrs/dereferenceable.ll
@@ -50,3 +50,79 @@
   ret i32* %0
 }
 
+declare void @fun0() #0
+declare void @fun1(i8*) #0
+declare void @fun2(i8*, i8*) #0
+declare void @fun3(i8*, i8*, i8*) #0
+; TEST 5 simple path test
+; if(..)
+;   fun2(dereferenceable(8) %a, dereferenceable(8) %b)
+; else
+;   fun2(dereferenceable(4) %a, %b)
+; We can say that %a is dereferenceable(4) but %b is not.
+define void @f5(i8* %a, i8* %b, i8 %c) {
+; ATTRIBUTOR: define void @f5(i8* nonnull dereferenceable(4) %a, i8* %b, i8 %c)
+  %cmp = icmp eq i8 %c, 0
+  br i1 %cmp, label %if.then, label %if.else
+if.then:
+  tail call void @fun2(i8* dereferenceable(8) %a, i8* dereferenceable(8) %b)
+  ret void
+if.else:
+  tail call void @fun2(i8* dereferenceable(4) %a, i8* %b)
+  ret void
+}
+; TEST 6 explore child BB test
+; if(..)
+;    ... (willreturn & nounwind)
+; else
+;    ... (willreturn & nounwind)
+; fun1(dereferenceable(8) %a)
+; We can say that %a is dereferenceable(8)
+define void @f6(i8* %a, i8 %c) {
+; ATTRIBUTOR: define void @f6(i8* nonnull dereferenceable(8) %a, i8 %c)
+  %cmp = icmp eq i8 %c, 0
+  br i1 %cmp, label %if.then, label %if.else
+if.then:
+  tail call void @fun0()
+  br label %cont
+if.else:
+  tail call void @fun0()
+  br label %cont
+cont:
+  tail call void @fun1(i8* dereferenceable(8) %a)
+  ret void
+}
+; TEST 7 More complex test
+; {
+; fun1(dereferenceable(4) %a)
+; if(..)
+;    ... (willreturn & nounwind)
+;    fun1(dereferenceable(12) %a)
+; else
+;    ... (willreturn & nounwind)
+;    fun1(dereferenceable(16) %a)
+; fun1(dereferenceable(8) %a)
+; }
+; %a is dereferenceable(12)
+define void @f7(i8* %a, i8* %b, i8 %c) {
+; ATTRIBUTOR: define void @f7(i8* nonnull dereferenceable(12) %a, i8* %b, i8 %c)
+  %cmp = icmp eq i8 %c, 0
+  tail call void @fun1(i8* dereferenceable(4) %a)
+  br i1 %cmp, label %cont.then, label %cont.else
+cont.then:
+  tail call void @fun1(i8* dereferenceable(12) %a)
+  br label %cont2
+cont.else:
+  tail call void @fun1(i8* dereferenceable(16) %a)
+  br label %cont2
+cont2:
+  tail call void @fun1(i8* dereferenceable(8) %a)
+  ret void
+}
+define void @f8(i8* %a) {
+; ATTRIBUTOR: define void @f8(i8* nonnull dereferenceable(4) %a)
+  tail call void @fun1(i8* dereferenceable(4) %a)
+  ret void
+}
+
+attributes #0 = { nounwind willreturn }
Index: llvm/test/Transforms/FunctionAttrs/nonnull.ll
===================================================================
--- llvm/test/Transforms/FunctionAttrs/nonnull.ll
+++ llvm/test/Transforms/FunctionAttrs/nonnull.ll
@@ -239,8 +239,7 @@
 ;   fun2(nonnull %a, %b)
 ; We can say that %a is nonnull but %b is not.
 define void @f16(i8* %a, i8 * %b, i8 %c) {
-; FIXME: missing nonnull on %a
-; ATTRIBUTOR: define void @f16(i8* %a, i8* %b, i8 %c)
+; ATTRIBUTOR: define void @f16(i8* nonnull %a, i8* %b, i8 %c)
   %cmp = icmp eq i8 %c, 0
   br i1 %cmp, label %if.then, label %if.else
 if.then:
@@ -258,8 +257,7 @@
 ; fun1(nonnull %a)
 ; We can say that %a is nonnull
 define void @f17(i8* %a, i8 %c) {
-; FIXME: missing nonnull on %a
-; ATTRIBUTOR: define void @f17(i8* %a, i8 %c)
+; ATTRIBUTOR: define void @f17(i8* nonnull %a, i8 %c)
   %cmp = icmp eq i8 %c, 0
   br i1 %cmp, label %if.then, label %if.else
 if.then:
@@ -284,8 +282,7 @@
 ; fun1(nonnull %a)
 
 define void @f18(i8* %a, i8* %b, i8 %c) {
-; FIXME: missing nonnull on %a
-; ATTRIBUTOR: define void @f18(i8* %a, i8* %b, i8 %c)
+; ATTRIBUTOR: define void @f18(i8* nonnull %a, i8* %b, i8 %c)
   %cmp1 = icmp eq i8 %c, 0
   br i1 %cmp1, label %if.then, label %if.else
 if.then:
@@ -311,8 +308,7 @@
 
 ; TEST 19: Loop
 define void @f19(i8* %a, i8* %b, i8 %c) {
-; FIXME: missing nonnull on %b
-; ATTRIBUTOR: define void @f19(i8* %a, i8* %b, i8 %c)
+; ATTRIBUTOR: define void @f19(i8* %a, i8* nonnull %b, i8 %c)
   br label %loop.header
 loop.header:
   %cmp2 = icmp eq i8 %c, 0