[InlineCost] Fold null checks of nonnull arguments when costing a call

If the callee compares one of its arguments against null and the call site
(or the callee) marks that argument `nonnull`, visitCmpInst now records the
comparison as simplified to a constant (true for `ne`, false for `eq`). This
is the same fold already done for alloca-derived arguments, which is kept.

CallAnalyzer now stores the CallSite it is analyzing so call-site parameter
attributes can be queried. Both construction sites pass it in.

The new test has a callee whose null-case block is too expensive to inline
unless that block is known to be dead.

Index: lib/Analysis/IPA/InlineCost.cpp
===================================================================
--- lib/Analysis/IPA/InlineCost.cpp
+++ lib/Analysis/IPA/InlineCost.cpp
@@ -72,6 +72,11 @@
   // The called function.
   Function &F;
 
+  // The callsite being analyzed - used to inspect call site specific
+  // attributes since these can be more precise than the ones on the callee
+  // itself.
+  CallSite CS;
+
   int Threshold;
   int Cost;
 
@@ -162,8 +167,8 @@
 
 public:
   CallAnalyzer(const TargetTransformInfo &TTI, AssumptionCacheTracker *ACT,
-               Function &Callee, int Threshold)
-      : TTI(TTI), ACT(ACT), F(Callee), Threshold(Threshold), Cost(0),
+               Function &Callee, int Threshold, CallSite CSArg)
+      : TTI(TTI), ACT(ACT), F(Callee), CS(CSArg), Threshold(Threshold), Cost(0),
         IsCallerRecursive(false), IsRecursiveCall(false),
         ExposesReturnsTwice(false), HasDynamicAlloca(false),
         ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false),
@@ -555,8 +560,18 @@
   }
 
   // If the comparison is an equality comparison with null, we can simplify it
-  // for any alloca-derived argument.
-  if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1)))
+  // if we know the value (argument) can't be null.
+  if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1))) {
+    // Does the call site or callee have the NonNull attribute set on this
+    // argument? Attribute indices are 1-based (index 0 is the return value).
+    Argument *A = dyn_cast<Argument>(I.getOperand(0));
+    if (A && CS.paramHasAttr(A->getArgNo() + 1, Attribute::NonNull)) {
+      bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE;
+      SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType())
+                                        : ConstantInt::getFalse(I.getType());
+      return true;
+    }
+    // Is this an alloca in the caller?
     if (isAllocaDerivedArg(I.getOperand(0))) {
       // We can actually predict the result of comparisons between an
       // alloca-derived value and null. Note that this fires regardless of
@@ -566,6 +581,7 @@
                                         : ConstantInt::getFalse(I.getType());
       return true;
     }
+  }
 
   // Finally check for SROA candidates in comparisons.
   Value *SROAArg;
@@ -841,7 +857,7 @@
   // during devirtualization and so we want to give it a hefty bonus for
   // inlining, but cap that bonus in the event that inlining wouldn't pan
   // out. Pretend to inline the function, with a custom threshold.
-  CallAnalyzer CA(TTI, ACT, *F, InlineConstants::IndirectCallThreshold);
+  CallAnalyzer CA(TTI, ACT, *F, InlineConstants::IndirectCallThreshold, CS);
   if (CA.analyzeCall(CS)) {
     // We were able to inline the indirect call! Subtract the cost from the
     // bonus we want to apply, but don't go below zero.
@@ -1444,7 +1460,7 @@
   DEBUG(llvm::dbgs() << "      Analyzing call of " << Callee->getName()
         << "...\n");
 
-  CallAnalyzer CA(TTIWP->getTTI(*Callee), ACT, *Callee, Threshold);
+  CallAnalyzer CA(TTIWP->getTTI(*Callee), ACT, *Callee, Threshold, CS);
   bool ShouldInline = CA.analyzeCall(CS);
 
   DEBUG(CA.dump());
Index: test/Transforms/Inline/nonnull.ll
===================================================================
--- /dev/null
+++ test/Transforms/Inline/nonnull.ll
@@ -0,0 +1,45 @@
+; RUN: opt -S -inline %s | FileCheck %s
+
+declare void @foo()
+declare void @bar()
+
+define void @callee(i8* %arg) {
+  %cmp = icmp eq i8* %arg, null
+  br i1 %cmp, label %expensive, label %done
+
+; This block is designed to be too expensive to inline. We can only inline
+; callee if this block is known to be dead.
+expensive:
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  call void @foo()
+  ret void
+
+done:
+  call void @bar()
+  ret void
+}
+
+; Positive test - arg is known non null
+define void @caller(i8* nonnull %arg) {
+; CHECK-LABEL: @caller
+; CHECK: call void @bar()
+  call void @callee(i8* nonnull %arg)
+  ret void
+}
+
+; Negative test - arg is not known to be non null
+define void @caller2(i8* %arg) {
+; CHECK-LABEL: @caller2
+; CHECK: call void @callee(
+  call void @callee(i8* %arg)
+  ret void
+}
+