Index: lib/Transforms/Scalar/SCCP.cpp =================================================================== --- lib/Transforms/Scalar/SCCP.cpp +++ lib/Transforms/Scalar/SCCP.cpp @@ -70,6 +70,8 @@ STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP"); STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); +STATISTIC(IPNumNonNullAdded, + "Number of nonnull attributes added by IPSCCP"); namespace { @@ -90,6 +92,9 @@ /// asserting. forcedconstant, + /// This Value is known to not have the specified value. + notconstant, + /// overdefined - This instruction is not known to be constant, and we know /// it has a value. overdefined @@ -97,7 +102,7 @@ /// Val: This stores the current lattice value along with the Constant* for /// the constant if this is a 'constant' or 'forcedconstant' value. - PointerIntPair Val; + PointerIntPair Val; LatticeValueTy getLatticeValue() const { return Val.getInt(); @@ -112,6 +117,8 @@ return getLatticeValue() == constant || getLatticeValue() == forcedconstant; } + bool isNotConstant() const { return getLatticeValue() == notconstant; } + bool isOverdefined() const { return getLatticeValue() == overdefined; } Constant *getConstant() const { @@ -153,6 +160,19 @@ return true; } + bool markNotConstant(Constant *V) { + assert(V && "Marking constant with NULL"); + if (isa(V)) + return false; + + assert((!isNotConstant() || getNotConstant() == V) && + "Marking !constant with different value"); + assert(isUnknown() || isNotConstant()); + Val.setPointer(V); + Val.setInt(notconstant); + return true; + } + /// getConstantInt - If this is a constant with a ConstantInt value, return it /// otherwise return null. 
ConstantInt *getConstantInt() const { @@ -161,6 +181,11 @@ return nullptr; } + Constant *getNotConstant() const { + assert(isNotConstant() && "Cannot get the constant of a non-notconstant!"); + return Val.getPointer(); + } + /// getBlockAddress - If this is a constant with a BlockAddress value, return /// it, otherwise return null. BlockAddress *getBlockAddress() const { @@ -180,6 +205,8 @@ return ValueLatticeElement::getOverdefined(); if (isConstant()) return ValueLatticeElement::get(getConstant()); + if (isNotConstant()) + return ValueLatticeElement::getNot(getNotConstant()); return ValueLatticeElement(); } }; @@ -443,6 +470,14 @@ pushToWorkList(IV, V); } + bool markNotConstant(LatticeVal &IV, Value *V, Constant *C) { + if (!IV.markNotConstant(C)) + return false; + LLVM_DEBUG(dbgs() << "markNotConstant: " << *C << ": " << *V << '\n'); + pushToWorkList(IV, V); + return true; + } + // markOverdefined - Make a value be marked as "overdefined". If the // value is not already overdefined, add it to the overdefined instruction // work list so that the users of the instruction are updated later. @@ -463,10 +498,30 @@ return false; // Noop. if (MergeWithV.isOverdefined()) return markOverdefined(IV, V); - if (IV.isUnknown()) - return markConstant(IV, V, MergeWithV.getConstant()); - if (IV.getConstant() != MergeWithV.getConstant()) - return markOverdefined(IV, V); + + if (IV.isUnknown()) { + // When merging an unknown value with a != X, we cannot conclude + // != X, because the unknown value could be undef and we do not replace + // notconstants after solving. So the undef value could take X later on. 
+ if (MergeWithV.isNotConstant()) + return markOverdefined(IV, V); + else { + assert(MergeWithV.isConstant() && "Unexpected lattice value type"); + return markConstant(IV, V, MergeWithV.getConstant()); + } + } + + if (IV.isNotConstant()) { + if (!MergeWithV.isNotConstant() || + IV.getNotConstant() != MergeWithV.getNotConstant()) + return markOverdefined(IV, V); + return false; + } + + if (IV.isConstant()) + if (!MergeWithV.isConstant() || + IV.getConstant() != MergeWithV.getConstant()) + return markOverdefined(IV, V); return false; } @@ -788,6 +843,9 @@ if (IV.isOverdefined()) // PHI node becomes overdefined! return (void)markOverdefined(&PN); + if (IV.isNotConstant()) + return (void)markOverdefined(&PN); + if (!OperandVal) { // Grab the first value. OperandVal = IV.getConstant(); continue; @@ -851,7 +909,8 @@ void SCCPSolver::visitCastInst(CastInst &I) { LatticeVal OpSt = getValueState(I.getOperand(0)); - if (OpSt.isOverdefined()) // Inherit overdefinedness of operand + if (OpSt.isOverdefined() || + OpSt.isNotConstant()) // Inherit overdefinedness of operand markOverdefined(&I); else if (OpSt.isConstant()) { // Fold the constant as we build. @@ -960,6 +1019,10 @@ LatticeVal &IV = ValueState[&I]; if (IV.isOverdefined()) return; + // TODO: could do better than that in some cases. 
+ if (V1State.isNotConstant() || V2State.isNotConstant()) + return (void)markOverdefined(&I); + if (V1State.isConstant() && V2State.isConstant()) { Constant *C = ConstantExpr::get(I.getOpcode(), V1State.getConstant(), V2State.getConstant()); @@ -988,9 +1051,9 @@ if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul || I.getOpcode() == Instruction::Or) { LatticeVal *NonOverdefVal = nullptr; - if (!V1State.isOverdefined()) + if (!V1State.isOverdefined() && !V1State.isNotConstant()) NonOverdefVal = &V1State; - else if (!V2State.isOverdefined()) + else if (!V2State.isOverdefined() && !V2State.isNotConstant()) NonOverdefVal = &V2State; if (NonOverdefVal) { @@ -1046,7 +1109,7 @@ } // If operands are still unknown, wait for it to resolve. - if (!V1State.isOverdefined() && !V2State.isOverdefined() && !IV.isConstant()) + if (!V1State.isOverdefined() && !V2State.isOverdefined() && IV.isUnknown()) return; markOverdefined(&I); @@ -1065,7 +1128,7 @@ if (State.isUnknown()) return; // Operands are not resolved yet. 
- if (State.isOverdefined()) + if (State.isOverdefined() || State.isNotConstant()) return (void)markOverdefined(&I); assert(State.isConstant() && "Unknown state!"); @@ -1192,20 +1255,30 @@ LatticeVal OriginalVal = getValueState(CopyOf); LatticeVal EqVal = getValueState(CmpOp1); LatticeVal &IV = ValueState[I]; - if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) { - addAdditionalUser(CmpOp1, I); - if (OriginalVal.isConstant()) - mergeInValue(IV, I, OriginalVal); - else - mergeInValue(IV, I, EqVal); + if (OriginalVal.isConstant()) { + mergeInValue(IV, I, OriginalVal); return; } + + if (Cmp->getPredicate() == CmpInst::ICMP_EQ) { + if (PBranch->TrueEdge) { + addAdditionalUser(CmpOp1, I); + mergeInValue(IV, I, EqVal); + return; + } else { + addAdditionalUser(CmpOp1, I); + if (EqVal.isConstant() && + (IV.isUnknown() || + (IV.isNotConstant() && + IV.getNotConstant() == EqVal.getConstant()))) { + markNotConstant(IV, I, EqVal.getConstant()); + return; + } + } + } if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) { addAdditionalUser(CmpOp1, I); - if (OriginalVal.isConstant()) - mergeInValue(IV, I, OriginalVal); - else - mergeInValue(IV, I, EqVal); + mergeInValue(IV, I, EqVal); return; } @@ -1693,12 +1766,40 @@ return false; } +/// Add nonnull to the pointer arguments of the call \p I that the solver +/// proved to be != null. Returns true iff \p I is a call. +static bool tryToAddNonNull(SCCPSolver &Solver, Instruction &I) { + CallInst *CI = dyn_cast(&I); + if (!CI) + return false; + + CallSite CS(CI); + unsigned NextArgNo = 0; + for (auto &A : CS.args()) { + // Take the index up front so that skipped arguments are still counted. + unsigned ArgNo = NextArgNo++; + if (!A->getType()->isPointerTy() || !isa(&*A)) + continue; + + // A notconstant value is never also constant, so "!= null" is the only + // lattice state that proves the argument nonnull here. + const LatticeVal &IV = Solver.getLatticeValueFor(&*A); + if (IV.isNotConstant() && IV.getNotConstant()->isNullValue()) { + IPNumNonNullAdded++; + CS.addParamAttr(ArgNo, Attribute::NonNull); + } + } + + return true; +} + static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { 
std::vector IVs = Solver.getStructLatticeValueFor(V); - if (llvm::any_of(IVs, - [](const LatticeVal &LV) { return LV.isOverdefined(); })) + if (llvm::any_of(IVs, [](const LatticeVal &LV) { + return LV.isOverdefined() || LV.isNotConstant(); + })) return false; std::vector ConstVals; auto *ST = dyn_cast(V->getType()); @@ -1711,7 +1812,7 @@ Const = ConstantStruct::get(ST, ConstVals); } else { const LatticeVal &IV = Solver.getLatticeValueFor(V); - if (IV.isOverdefined()) + if (IV.isOverdefined() || IV.isNotConstant()) return false; Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); @@ -1979,6 +2080,7 @@ Instruction *Inst = &*BI++; if (Inst->getType()->isVoidTy()) continue; + tryToAddNonNull(Solver, *Inst); if (tryToReplaceWithConstant(Solver, Inst)) { if (Inst->isSafeToRemove()) Inst->eraseFromParent(); Index: test/Transforms/SCCP/ipsccp-nonnull.ll =================================================================== --- /dev/null +++ test/Transforms/SCCP/ipsccp-nonnull.ll @@ -0,0 +1,38 @@ +; RUN: opt < %s -ipsccp -S | FileCheck %s + +; CHECK-LABEL: @test1 +; CHECK: %call = call i1 @testf(i32* nonnull %ptr) +declare i1 @testf(i32* %ptr) + +define i1 @test1(i32* %ptr) { +entry: + %cond = icmp eq i32* %ptr, null + br i1 %cond, label %if.then, label %if.end + +if.then: ; preds = %entry + ret i1 true + +if.end: ; preds = %if.then, %entry + %call = call i1 @testf(i32* %ptr) + ret i1 %call +} + +; We can conclude that %ptr != 0, but we cannot conclude %ptr1 != 0, because we +; cannot replace uses of %ptr1, so %ptr1 could be undef, which in turn could be +; resolved to 0. + +; CHECK-LABEL: @test2 +; CHECK: %call = call i1 @testf(i32* %ptr1) +define i1 @test2(i32* %ptr, i1 %c) { +entry: + %cond = icmp eq i32* %ptr, null + br i1 %cond, label %if.then, label %if.end + +if.then: ; preds = %entry + ret i1 true + +if.end: ; preds = %if.then, %entry + %ptr1 = select i1 %c, i32* undef, i32* %ptr + %call = call i1 @testf(i32* %ptr1) + ret i1 %call +}