Index: include/llvm/Analysis/ValueLattice.h
===================================================================
--- include/llvm/Analysis/ValueLattice.h
+++ include/llvm/Analysis/ValueLattice.h
@@ -299,6 +299,18 @@
     if (isConstant() && Other.isConstant())
       return ConstantExpr::getCompare(Pred, getConstant(), Other.getConstant());
 
+    // A value known to differ from constant C compares unequal to C,
+    // regardless of operand order.
+    if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) {
+      bool KnownNE = (isNotConstant() && Other.isConstant() &&
+                      getNotConstant() == Other.getConstant()) ||
+                     (isConstant() && Other.isNotConstant() &&
+                      getConstant() == Other.getNotConstant());
+      if (KnownNE)
+        return Pred == CmpInst::ICMP_EQ ? ConstantInt::getFalse(Ty)
+                                        : ConstantInt::getTrue(Ty);
+    }
+
     // Integer constants are represented as ConstantRanges with single
     // elements.
     if (!isConstantRange() || !Other.isConstantRange())
Index: include/llvm/Transforms/Utils/PredicateInfo.h
===================================================================
--- include/llvm/Transforms/Utils/PredicateInfo.h
+++ include/llvm/Transforms/Utils/PredicateInfo.h
@@ -93,7 +93,7 @@
 class LLVMContext;
 class raw_ostream;
 
-enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
+enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_NE };
 
 // Base class for all predicate information we provide.
 // All of our predicate information has at least a comparison.
@@ -141,6 +141,19 @@
   }
 };
 
+// Provides predicate information for pointers passed as nonnull call site
+// arguments. There is no comparison to expose, so this derives from
+// PredicateBase rather than PredicateWithCondition. I is the instruction that
+// implies non-nullness, so you can tell your relative position to it.
+class PredicateNonNull : public PredicateBase {
+public:
+  Instruction *I;
+  PredicateNonNull(Value *Op, Instruction *I)
+      : PredicateBase(PT_NE, Op), I(I) {}
+  PredicateNonNull() = delete;
+  static bool classof(const PredicateBase *PB) { return PB->Type == PT_NE; }
+};
+
 // Mixin class for edge predicates.  The FROM block is the block where the
 // predicate originates, and the TO block is the block where the predicate is
 // valid.
Index: lib/Transforms/Scalar/SCCP.cpp
===================================================================
--- lib/Transforms/Scalar/SCCP.cpp
+++ lib/Transforms/Scalar/SCCP.cpp
@@ -1223,6 +1223,25 @@
       if (!PI)
         return;
 
+      if (isa<PredicateNonNull>(PI)) {
+        LatticeVal OriginalVal = getValueState(PI->OriginalOp);
+        LatticeVal &IV = ValueState[I];
+
+        // A known constant is more precise than non-null; forward it.
+        if (OriginalVal.isConstant()) {
+          mergeInValue(IV, I, OriginalVal);
+          return;
+        }
+
+        // Otherwise the copied pointer is known to be non-null.
+        if (IV.isUnknown() ||
+            (IV.isNotConstant() && IV.getNotConstant()->isNullValue())) {
+          markNotConstant(
+              IV, I, ConstantPointerNull::get(cast<PointerType>(I->getType())));
+          return;
+        }
+      }
+
       auto *PBranch = dyn_cast<PredicateBranch>(getPredicateInfoFor(I));
       if (!PBranch) {
         mergeInValue(ValueState[I], I, getValueState(PI->OriginalOp));
Index: lib/Transforms/Utils/PredicateInfo.cpp
===================================================================
--- lib/Transforms/Utils/PredicateInfo.cpp
+++ lib/Transforms/Utils/PredicateInfo.cpp
@@ -184,6 +184,10 @@
     if (!VD.U) {
       assert(VD.PInfo &&
              "No def, no use, and no predicateinfo should not occur");
+      // Nonnull predicates go in front of the instruction that implies them.
+      if (auto *PNonNull = dyn_cast<PredicateNonNull>(VD.PInfo))
+        return PNonNull->I;
+
       assert(isa<PredicateAssume>(VD.PInfo) &&
              "Middle of block should only occur for assumes");
       return cast<PredicateAssume>(VD.PInfo)->AssumeInst;
@@ -461,6 +465,24 @@
   SmallPtrSet<Value *, 8> OpsToRename;
   for (auto DTN : depth_first(DT.getRootNode())) {
     BasicBlock *BranchBB = DTN->getBlock();
+    // Pointers passed as nonnull call site arguments are known to be non-null
+    // at the call, so record a predicate for them.
+    for (Instruction &I : *BranchBB) {
+      CallSite CS(&I);
+      if (!CS)
+        continue;
+      for (unsigned ArgNo = 0, E = CS.arg_size(); ArgNo != E; ++ArgNo) {
+        if (!CS.paramHasAttr(ArgNo, Attribute::NonNull))
+          continue;
+        Value *Arg = CS.getArgOperand(ArgNo);
+        // Only instructions and arguments with other uses are worth renaming.
+        if (!Arg->getType()->isPointerTy() ||
+            (!isa<Instruction>(Arg) && !isa<Argument>(Arg)) ||
+            Arg->hasOneUse())
+          continue;
+        addInfoFor(OpsToRename, Arg, new PredicateNonNull(Arg, &I));
+      }
+    }
     if (auto *BI = dyn_cast<BranchInst>(BranchBB->getTerminator())) {
       if (!BI->isConditional())
         continue;
@@ -529,10 +551,17 @@
       PredicateMap.insert({PIC, ValInfo});
       Result.Def = PIC;
     } else {
-      auto *PAssume = dyn_cast<PredicateAssume>(ValInfo);
-      assert(PAssume &&
-             "Should not have gotten here without it being an assume");
-      IRBuilder<> B(PAssume->AssumeInst);
+      // Assume and nonnull predicates are materialized in front of the
+      // instruction that implies them.
+      Instruction *InsertPt;
+      if (auto *PNonNull = dyn_cast<PredicateNonNull>(ValInfo)) {
+        InsertPt = PNonNull->I;
+      } else {
+        assert(isa<PredicateAssume>(ValInfo) &&
+               "Should not have gotten here without it being an assume");
+        InsertPt = cast<PredicateAssume>(ValInfo)->AssumeInst;
+      }
+      IRBuilder<> B(InsertPt);
       Function *IF = getCopyDeclaration(F.getParent(), Op->getType());
       if (IF->user_begin() == IF->user_end())
         CreatedDeclarations.insert(IF);
@@ -596,6 +625,16 @@
         VD.DFSOut = DomNode->getDFSNumOut();
         VD.PInfo = PossibleCopy;
         OrderedUses.push_back(VD);
+      } else if (const auto *PNonNull =
+                     dyn_cast<PredicateNonNull>(PossibleCopy)) {
+        VD.LocalNum = LN_Middle;
+        DomTreeNode *DomNode = DT.getNode(PNonNull->I->getParent());
+        if (!DomNode)
+          continue;
+        VD.DFSIn = DomNode->getDFSNumIn();
+        VD.DFSOut = DomNode->getDFSNumOut();
+        VD.PInfo = PossibleCopy;
+        OrderedUses.push_back(VD);
       } else if (isa<PredicateWithEdge>(PossibleCopy)) {
         // If we can only do phi uses, we treat it like it's in the branch
         // block, and handle it specially. We know that it goes last, and only
@@ -824,6 +863,9 @@
       } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
         OS << "; assume predicate info {"
            << " Comparison:" << *PA->Condition << " }\n";
+      } else if (const auto *PNonNull = dyn_cast<PredicateNonNull>(PI)) {
+        OS << "; nonnull predicate info {"
+           << " Instruction:" << *PNonNull->I << " }\n";
       }
     }
   }
Index: test/Transforms/SCCP/ipsccp-nonnull.ll
===================================================================
--- test/Transforms/SCCP/ipsccp-nonnull.ll
+++ test/Transforms/SCCP/ipsccp-nonnull.ll
@@ -15,3 +15,62 @@
   %call = call i1 @testf(i32* %ptr)
   ret i1 %call
 }
+
+; CHECK-LABEL: @bar2(
+; CHECK: %call = call i1 @testf(i32* nonnull %ptr)
+; CHECK-NEXT: br label %if.end
+define i1 @bar2(i32* %ptr) {
+entry:
+  %call = call i1 @testf(i32* nonnull %ptr)
+  %cond = icmp eq i32* %ptr, null
+  br i1 %cond, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  ret i1 true
+
+if.end:                                           ; preds = %entry
+  ret i1 %call
+}
+
+; CHECK-LABEL: @bar3(
+; CHECK: br i1 %c1, label %bb1, label %bb2
+; CHECK-LABEL: bb1:
+; CHECK-NEXT: %call = call i1 @testf(i32* nonnull %ptr)
+; CHECK-NEXT: ret i1 %call
+; CHECK-LABEL: bb2:
+; CHECK-NEXT: %cond = icmp eq i32* %ptr, null
+; CHECK-NEXT: br i1 %cond, label %if.then, label %if.end
+define i1 @bar3(i1 %c1, i32* %ptr) {
+entry:
+  br i1 %c1, label %bb1, label %bb2
+
+bb1:
+  %call = call i1 @testf(i32* nonnull %ptr)
+  ret i1 %call
+
+bb2:                                              ; preds = %entry
+  %cond = icmp eq i32* %ptr, null
+  br i1 %cond, label %if.then, label %if.end
+
+if.then:
+  ret i1 false
+
+if.end:
+  ret i1 true
+}
+
+; CHECK-LABEL: @bar4(
+; CHECK: %call = call i1 @testf(i32* nonnull %ptr)
+; CHECK-NEXT: br label %if.then
+define i1 @bar4(i32* %ptr) {
+entry:
+  %call = call i1 @testf(i32* nonnull %ptr)
+  %cond = icmp ne i32* %ptr, null
+  br i1 %cond, label %if.then, label %if.end
+
+if.then:                                          ; preds = %entry
+  ret i1 %call
+
+if.end:                                           ; preds = %entry
+  ret i1 false
+}