Index: lib/Transforms/IPO/Inliner.cpp
===================================================================
--- lib/Transforms/IPO/Inliner.cpp
+++ lib/Transforms/IPO/Inliner.cpp
@@ -32,12 +32,15 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "inline"
 
@@ -335,6 +338,242 @@
   return false;
 }
 
+// Mark every argument of NewCallI that is exactly Op as NonNull. NewCallI is
+// created on first use as a clone of CallI, placed right after it.
+static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI,
+                                Value *Op, Constant *ConstValue) {
+  if (!NewCallI) {
+    NewCallI = CallI->clone();
+    NewCallI->insertAfter(CallI);
+  }
+  CallSite CS(NewCallI);
+  unsigned ArgNo = 0;
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
+       ++I, ++ArgNo)
+    if (*I == Op)
+      CS.addParamAttr(ArgNo, Attribute::NonNull);
+}
+
+// Replace every argument of NewCallI that is exactly Op with ConstValue.
+// NewCallI is created on first use as a clone of CallI, placed right after it.
+static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI,
+                                  Value *Op, Constant *ConstValue) {
+  if (!NewCallI) {
+    NewCallI = CallI->clone();
+    NewCallI->insertAfter(CallI);
+  }
+  CallSite CS(NewCallI);
+  unsigned ArgNo = 0;
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
+       ++I, ++ArgNo)
+    if (*I == Op)
+      CS.setArgument(ArgNo, ConstValue);
+}
+
+// Create constrained clones of the call Instr for the incoming edges of its
+// block, based on the compares in BranchInsts. TopTakenCI is the clone for the
+// edge coming directly from TopBB, TopUntakenCI the clone for the edge coming
+// through the other predecessor. Returns true if any clone was created.
+static bool createCallSitesWithConstrainedArgument(
+    Instruction *Instr, Instruction *&TopTakenCI, Instruction *&TopUntakenCI,
+    SmallVectorImpl<BranchInst *> &BranchInsts, BasicBlock *TopBB) {
+  assert(BranchInsts.size() <= 2 &&
+         "Unexpected number of blocks in the OR predicated condition");
+  BasicBlock *CallSiteBB = Instr->getParent();
+  TerminatorInst *TopTI = TopBB->getTerminator();
+  bool IsCSInTakenPath = CallSiteBB == TopTI->getSuccessor(0);
+
+  for (unsigned I = 0, E = BranchInsts.size(); I != E; ++I) {
+    BranchInst *PBI = BranchInsts[I];
+    assert(PBI->isConditional());
+    ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition());
+    Value *Op0 = Cmp->getOperand(0);
+    Constant *Op1 = cast<Constant>(Cmp->getOperand(1));
+    CmpInst::Predicate Pred = Cmp->getPredicate();
+
+    if (PBI->getParent() == TopBB) {
+      Instruction *&CallTakenFromTop = IsCSInTakenPath ? TopTakenCI : TopUntakenCI;
+      Instruction *&CallUntakenFromTop = IsCSInTakenPath ? TopUntakenCI : TopTakenCI;
+
+      assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
+             "Unexpected predicate in an OR condition");
+
+      // Op0 == Op1 holds on the taken path of an 'eq' branch and on the
+      // untaken path of an 'ne' branch; that path's call gets the constant.
+      Instruction *&CallTaken =
+          Pred == ICmpInst::ICMP_EQ ? CallTakenFromTop : CallUntakenFromTop;
+      setConstantInArgument(Instr, CallTaken, Op0, Op1);
+
+      // On the other path Op0 != Op1, so when Op1 is the null pointer the
+      // argument is known to be non-null there.
+      if (Op1->getType()->isPointerTy() && Op1->isNullValue()) {
+        Instruction *&CallUntaken =
+            Pred == ICmpInst::ICMP_EQ ? CallUntakenFromTop : CallTakenFromTop;
+        addNonNullAttribute(Instr, CallUntaken, Op0, Op1);
+      }
+
+    } else {
+      Instruction *&CallUntaken = TopUntakenCI;
+      if (Pred == ICmpInst::ICMP_EQ) {
+        if (PBI->getSuccessor(0) == Instr->getParent()) {
+          // Set the constant value for the call in the untaken path from the
+          // top block.
+          setConstantInArgument(Instr, CallUntaken, Op0, Op1);
+        } else {
+          // Add the NonNull attribute if compared with the null pointer for the
+          // call in the untaken path from the top block.
+          if (Op1->getType()->isPointerTy() && Op1->isNullValue())
+            addNonNullAttribute(Instr, CallUntaken, Op0, Op1);
+        }
+
+      } else {
+        if (PBI->getSuccessor(0) == Instr->getParent()) {
+          // Add the NonNull attribute if compared with the null pointer for the
+          // call in the untaken path from the top block.
+          if (Op1->getType()->isPointerTy() && Op1->isNullValue())
+            addNonNullAttribute(Instr, CallUntaken, Op0, Op1);
+        } else if (Pred == ICmpInst::ICMP_NE) {
+          // Set the constant value for the call in the untaken path from the
+          // top block.
+          setConstantInArgument(Instr, CallUntaken, Op0, Op1);
+        } else
+          llvm_unreachable("Unexpected condition");
+      }
+    }
+  }
+  return TopTakenCI || TopUntakenCI;
+}
+
+// Split the two incoming edges of the call's block: the edge from TopBB gets a
+// ".taken.split" block holding CallTaken, the other edge an ".untaken.split"
+// block holding CallUntaken. A side passed as null gets a plain clone of the
+// original call (not reported back, since the pointers are taken by value). A
+// PHI merges the two results and the original call is erased. Returns false,
+// before any call is moved or erased, if either edge cannot be split.
+static bool splitOrConds(CallGraph &CG, CallSite CS, BasicBlock *TopBB,
+                         Instruction *CallTaken, Instruction *CallUntaken) {
+  assert((CallTaken || CallUntaken) && "Expect at least one new call site");
+  Instruction *Instr = CS.getInstruction();
+  Function *Caller = CS.getCaller();
+  Function *Callee = CS.getCalledFunction();
+
+  BasicBlock *CallSiteBB = Instr->getParent();
+  pred_iterator PII = pred_begin(CallSiteBB);
+  BasicBlock *Pred1 = *PII++;
+  BasicBlock *Pred2 = *PII;
+
+  BasicBlock *NextCond;
+  if (TopBB == Pred1)
+    NextCond = Pred2;
+  else if (TopBB == Pred2)
+    NextCond = Pred1;
+  else
+    llvm_unreachable("Unexpected OR condition");
+
+  BasicBlock *TakenBlock =
+      SplitBlockPredecessors(CallSiteBB, TopBB, ".taken.split");
+  BasicBlock *UntakenBlock =
+      SplitBlockPredecessors(CallSiteBB, NextCond, ".untaken.split");
+  if (!TakenBlock || !UntakenBlock)
+    return false;
+
+  if (!CallTaken) {
+    CallTaken = Instr->clone();
+    CallTaken->insertBefore(&*TakenBlock->getFirstInsertionPt());
+  } else
+    CallTaken->moveBefore(&*TakenBlock->getFirstInsertionPt());
+
+  if (!CallUntaken) {
+    CallUntaken = Instr->clone();
+    CallUntaken->insertBefore(&*UntakenBlock->getFirstInsertionPt());
+  } else
+    CallUntaken->moveBefore(&*UntakenBlock->getFirstInsertionPt());
+
+  CallSite CSTaken(CallTaken);
+  CallSite CSUntaken(CallUntaken);
+
+  CG[Caller]->addCalledFunction(CSTaken, CG[Callee]);
+  CG[Caller]->addCalledFunction(CSUntaken, CG[Callee]);
+  CG[Caller]->removeCallEdgeFor(CS);
+
+  // Replace users of the original call with a PHI merging the split calls.
+  if (Instr->getNumUses()) {
+    PHINode *PN = PHINode::Create(Instr->getType(), 2, "call.phi", Instr);
+    PN->addIncoming(CallTaken, TakenBlock);
+    PN->addIncoming(CallUntaken, UntakenBlock);
+    Instr->replaceAllUsesWith(PN);
+  }
+  Instr->eraseFromParent();
+  return true;
+}
+
+// Return true if Cmp's first operand is passed as an argument of CS that is
+// not already marked NonNull.
+static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
+  assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
+  Value *Op0 = Cmp->getOperand(0);
+  unsigned ArgNo = 0;
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
+       ++I, ++ArgNo) {
+    // Don't consider arguments that are already known non-null.
+    if (CS.paramHasAttr(ArgNo, Attribute::NonNull))
+      continue;
+
+    if (*I == Op0)
+      return true;
+  }
+  return false;
+}
+
+// Inspect the conditional branch ending PredBB. Record PredBB as TopBB when it
+// branches to OtherPredBB and is OtherPredBB's single predecessor, and append
+// the branch to BranchInsts when it is an 'eq'/'ne' compare against a constant
+// of a value passed as a (not yet NonNull) argument of CS.
+static void findOrCondRelevantToCallArgument(
+    CallSite CS, BasicBlock *PredBB, BasicBlock *OtherPredBB,
+    SmallVectorImpl<BranchInst *> &BranchInsts, BasicBlock *&TopBB) {
+  auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator());
+  if (!PBI || !PBI->isConditional())
+    return;
+
+  if (OtherPredBB)
+    if (PBI->getSuccessor(0) == OtherPredBB ||
+        PBI->getSuccessor(1) == OtherPredBB)
+      if (PredBB == OtherPredBB->getSinglePredecessor()) {
+        assert(TopBB == nullptr && "Expect to find only a single top block");
+        TopBB = PredBB;
+      }
+
+  CmpInst::Predicate Pred;
+  Value *Cond = PBI->getCondition();
+  if (match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) {
+    ICmpInst *Cmp = cast<ICmpInst>(Cond);
+    if (isCondRelevantToAnyCallArgument(Cmp, CS))
+      if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
+        BranchInsts.push_back(PBI);
+  }
+}
+
+// Return true if an argument in CS is predicated on an 'or' condition.
+static bool
+isPredicatedOnOrCondition(CallSite CS,
+                          SmallVectorImpl<BranchInst *> &BranchInsts,
+                          BasicBlock *&TopBB) {
+  BasicBlock *ParentBB = CS.getInstruction()->getParent();
+
+  // An 'or' condition reaches the call block through exactly two predecessors.
+  pred_iterator PII = pred_begin(ParentBB);
+  pred_iterator PIE = pred_end(ParentBB);
+  unsigned NumPreds = std::distance(PII, PIE);
+  if (NumPreds != 2)
+    return false;
+
+  BasicBlock *Preds[2] = {*PII++, *PII};
+  findOrCondRelevantToCallArgument(CS, Preds[0], Preds[1], BranchInsts, TopBB);
+  findOrCondRelevantToCallArgument(CS, Preds[1], Preds[0], BranchInsts, TopBB);
+  return !BranchInsts.empty() && TopBB != nullptr;
+}
+
 /// Return the cost only if the inliner should attempt to inline at the given
 /// CallSite. If we return the cost, we will emit an optimisation remark later
 /// using that cost, so we won't do so from this function.
@@ -397,6 +636,111 @@
   return IC;
 }
 
+// If a call site is dominated by an OR condition and if any of its arguments
+// are predicated on this OR condition, see if splitting the condition (and
+// thereby further constraining the arguments) increases our opportunities to
+// inline the call.
+//
+// For example, in the code below, if callee() is not inlinable, we try to
+// split the call site since we can predicate the argument (ptr) based on the OR
+// condition. Inline if any of the new call sites is inlinable.
+//
+// Split from :
+//   if (!ptr || c)
+//     callee(ptr);
+// to :
+//   if (!ptr)
+//     callee(null)         // set the known constant value
+//   else if (c)
+//     callee(nonnull ptr)  // set non-null attribute in the argument
+//
+// If shouldInline() accepts either callee(null) or callee(nonnull ptr), the
+// split is performed and the cheaper accepted call is returned, with Cost and
+// Threshold taken from its inline cost. Otherwise nullptr is returned and the
+// temporary clones are erased.
+static Instruction *tryToInlineIfPredicatedOnOrCondition(
+    CallGraph &CG, CallSite CS, int &Cost, int &Threshold,
+    function_ref<InlineCost(CallSite CS)> GetInlineCost,
+    OptimizationRemarkEmitter &ORE) {
+  // Only a plain call that is the first instruction of its block is split:
+  // just the call is cloned into each new predecessor block and no extra
+  // conditional branches are added. An invoke is a terminator and a musttail
+  // call must stay directly before its ret, so neither can be moved this way.
+  Instruction *Instr = CS.getInstruction();
+  if (!CS.arg_size() || !CS.isCall() || CS.isMustTailCall() ||
+      Instr != &*Instr->getParent()->begin())
+    return nullptr;
+
+  SmallVector<BranchInst *, 2> BranchInsts;
+  BasicBlock *TopBB = nullptr;
+  if (!isPredicatedOnOrCondition(CS, BranchInsts, TopBB))
+    return nullptr;
+
+  Instruction *CallTaken = nullptr;
+  Instruction *CallUntaken = nullptr;
+
+  // Based on the OR predicated condition, temporarily create call sites with
+  // the NonNull attribute or constant value in arguments.
+  if (!createCallSitesWithConstrainedArgument(Instr, CallTaken, CallUntaken,
+                                              BranchInsts, TopBB))
+    return nullptr;
+
+  int CostOfTaken = INT_MAX;
+  int CostOfUntaken = INT_MAX;
+  int ThresholdOfTaken = INT_MIN;
+  int ThresholdOfUntaken = INT_MIN;
+
+  if (CallTaken) {
+    CallSite CSTaken(CallTaken);
+    Optional<InlineCost> OICTaken = shouldInline(CSTaken, GetInlineCost, ORE);
+    if (OICTaken) {
+      CostOfTaken = OICTaken->getCost();
+      ThresholdOfTaken = OICTaken->getThreshold();
+    }
+  }
+
+  if (CallUntaken) {
+    CallSite CSUntaken(CallUntaken);
+    Optional<InlineCost> OICUntaken =
+        shouldInline(CSUntaken, GetInlineCost, ORE);
+    if (OICUntaken) {
+      CostOfUntaken = OICUntaken->getCost();
+      ThresholdOfUntaken = OICUntaken->getThreshold();
+    }
+  }
+
+  // See if any new call site created above is turned into inlinable.
+  if (CostOfTaken != INT_MAX || CostOfUntaken != INT_MAX) {
+    // On failure splitOrConds() has neither moved nor erased any call, so the
+    // temporary clones are still next to the original call.
+    if (!splitOrConds(CG, CS, TopBB, CallTaken, CallUntaken)) {
+      if (CallTaken)
+        CallTaken->eraseFromParent();
+      if (CallUntaken)
+        CallUntaken->eraseFromParent();
+      return nullptr;
+    }
+
+    // A side that was not cloned or not accepted by shouldInline keeps cost
+    // INT_MAX, so the call returned below is always an accepted clone.
+    if (CostOfTaken > CostOfUntaken) {
+      Cost = CostOfUntaken;
+      Threshold = ThresholdOfUntaken;
+      return CallUntaken;
+    }
+    Cost = CostOfTaken;
+    Threshold = ThresholdOfTaken;
+    return CallTaken;
+  }
+
+  // Neither constrained call became inlinable: drop the temporary clones.
+  if (CallTaken)
+    CallTaken->eraseFromParent();
+  if (CallUntaken)
+    CallUntaken->eraseFromParent();
+  return nullptr;
+}
+
 /// Return true if the specified inline history ID
 /// indicates an inline history that includes the specified function.
 static bool InlineHistoryIncludes(
@@ -543,10 +887,28 @@
       OptimizationRemarkEmitter ORE(Caller);
 
       Optional<InlineCost> OIC = shouldInline(CS, GetInlineCost, ORE);
+      bool IsAlways = false;
+      int Cost = 0, Threshold = 0;
+
       // If the policy determines that we should inline this function,
       // delete the call instead.
-      if (!OIC)
-        continue;
+      if (!OIC) {
+        if (Instruction *InlinableInst = tryToInlineIfPredicatedOnOrCondition(
+                CG, CS, Cost, Threshold, GetInlineCost, ORE)) {
+          // NOTE(review): the split erased the original call instruction, so
+          // any worklist entry still referring to it must be switched to
+          // InlinableCS before it is revisited. TODO confirm.
+          CallSite InlinableCS(InlinableInst);
+          CS = InlinableCS;
+        } else
+          continue;
+      } else {
+        IsAlways = OIC->isAlways();
+        if (OIC->isVariable()) {
+          Cost = OIC->getCost();
+          Threshold = OIC->getThreshold();
+        }
+      }
 
       // If this call site is dead and it is to a readonly function, we should
       // just delete the call instead of trying to inline it, regardless of
@@ -576,17 +938,15 @@
       }
       ++NumInlined;
 
-      if (OIC->isAlways())
+      if (IsAlways)
         ORE.emit(OptimizationRemark(DEBUG_TYPE, "AlwaysInline", DLoc, Block)
                  << NV("Callee", Callee) << " inlined into "
                  << NV("Caller", Caller) << " with cost=always");
       else
         ORE.emit(OptimizationRemark(DEBUG_TYPE, "Inlined", DLoc, Block)
                  << NV("Callee", Callee) << " inlined into "
-                 << NV("Caller", Caller)
-                 << " with cost=" << NV("Cost", OIC->getCost())
-                 << " (threshold=" << NV("Threshold", OIC->getThreshold())
-                 << ")");
+                 << NV("Caller", Caller) << " with cost=" << NV("Cost", Cost)
+                 << " (threshold=" << NV("Threshold", Threshold) << ")");
 
       // If inlining this function gave us any new call sites, throw them
       // onto our worklist to process. They are useful inline candidates.
Index: test/Transforms/Inline/inline-predicated-or.ll
===================================================================
--- /dev/null
+++ test/Transforms/Inline/inline-predicated-or.ll
@@ -0,0 +1,67 @@
+; RUN: opt < %s -inline -instcombine -jump-threading -S | FileCheck %s
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-linaro-linux-gnueabi"
+
+%struct.bitmap = type { i32, %struct.bitmap* }
+; Split the OR in @caller: only the call on the %NextCond edge gets inlined.
+;CHECK-LABEL: @caller
+;CHECK-LABEL: NextCond:
+;CHECK: br {{.*}} label %callee.exit
+;CHECK-LABEL: CallSiteBB.taken.split:
+;CHECK: call void @callee(%struct.bitmap* null, %struct.bitmap* null, %struct.bitmap* %b_elt)
+;CHECK-LABEL: callee.exit:
+;CHECK: call void @dummy2(%struct.bitmap* %a_elt)
+
+define void @caller(i1 %c, %struct.bitmap* %a_elt, %struct.bitmap* %b_elt) {
+entry:
+  br label %Top
+
+Top:
+  %tobool1 = icmp eq %struct.bitmap* %a_elt, null
+  br i1 %tobool1, label %CallSiteBB, label %NextCond ; a_elt is null on the edge to %CallSiteBB
+
+NextCond:
+  %cmp = icmp ne %struct.bitmap* %b_elt, null
+  br i1 %cmp, label %CallSiteBB, label %End ; a_elt and b_elt are both non-null on the edge to %CallSiteBB
+
+CallSiteBB:
+  call void @callee(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %b_elt)
+  br label %End
+
+End:
+  ret void
+}
+
+define void @callee(%struct.bitmap* %dst_elt, %struct.bitmap* %a_elt, %struct.bitmap* %b_elt) { ; %Small (one call) needs a_elt, b_elt non-null and dst_elt == a_elt
+entry:
+  %tobool = icmp ne %struct.bitmap* %a_elt, null
+  %tobool1 = icmp ne %struct.bitmap* %b_elt, null
+  %or.cond = and i1 %tobool, %tobool1
+  br i1 %or.cond, label %Cond, label %Big
+
+Cond:
+  %cmp = icmp eq %struct.bitmap* %dst_elt, %a_elt
+  br i1 %cmp, label %Small, label %Big
+
+Small:
+  call void @dummy2(%struct.bitmap* %a_elt)
+  br label %End
+
+Big:
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  call void @dummy1(%struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt, %struct.bitmap* %a_elt)
+  br label %End
+
+End:
+  ret void
+}
+
+declare void @dummy2(%struct.bitmap*)
+declare void @dummy1(%struct.bitmap*, %struct.bitmap*, %struct.bitmap*, %struct.bitmap*, %struct.bitmap*, %struct.bitmap*)
+