Index: lib/Transforms/Scalar/CallSiteSplitting.cpp =================================================================== --- lib/Transforms/Scalar/CallSiteSplitting.cpp +++ lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -72,10 +72,8 @@ STATISTIC(NumCallSiteSplit, "Number of call-site split"); -static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI, +static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI, Value *Op) { - if (!NewCallI) - NewCallI = CallI->clone(); CallSite CS(NewCallI); unsigned ArgNo = 0; for (auto &I : CS.args()) { @@ -85,10 +83,8 @@ } } -static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI, +static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI, Value *Op, Constant *ConstValue) { - if (!NewCallI) - NewCallI = CallI->clone(); CallSite CS(NewCallI); unsigned ArgNo = 0; for (auto &I : CS.args()) { @@ -114,99 +110,58 @@ return false; } -static SmallVector -findOrCondRelevantToCallArgument(CallSite CS) { - SmallVector BranchInsts; - for (auto PredBB : predecessors(CS.getInstruction()->getParent())) { - auto *PBI = dyn_cast(PredBB->getTerminator()); - if (!PBI || !PBI->isConditional()) - continue; +/// If From has a conditional jump to To, add the condition to Conditions, +/// if it is relevant to any argument at CS. +void recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To, SmallVectorImpl> &Conditions) { + auto *BI = dyn_cast(From->getTerminator()); + if (!BI || !BI->isConditional()) + return; + + CmpInst::Predicate Pred; + Value *Cond = BI->getCondition(); + if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) + return; + + ICmpInst *Cmp = cast(Cond); + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) + if (isCondRelevantToAnyCallArgument(Cmp, CS)) + Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To ? 
Pred : Cmp->getInversePredicate()}); +} - CmpInst::Predicate Pred; - Value *Cond = PBI->getCondition(); - if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) - continue; - ICmpInst *Cmp = cast(Cond); - if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) - if (isCondRelevantToAnyCallArgument(Cmp, CS)) - BranchInsts.push_back(PBI); +/// Record ICmp conditions relevant to any argument in CS following Pred's +/// single successors. +void recordConditions(const CallSite &CS, BasicBlock *Pred, SmallVectorImpl> &Conditions) { + recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + BasicBlock *From = Pred; + BasicBlock *To = Pred; + while ((From = From->getSinglePredecessor())) { + recordCondition(CS, From, To, Conditions); + To = From; } - return BranchInsts; } -static bool tryCreateCallSitesOnOrPredicatedArgument( - CallSite CS, Instruction *&NewCSTakenFromHeader, - Instruction *&NewCSTakenFromNextCond, BasicBlock *HeaderBB) { - auto BranchInsts = findOrCondRelevantToCallArgument(CS); - assert(BranchInsts.size() <= 2 && - "Unexpected number of blocks in the OR predicated condition"); - Instruction *Instr = CS.getInstruction(); - BasicBlock *CallSiteBB = Instr->getParent(); - TerminatorInst *HeaderTI = HeaderBB->getTerminator(); - bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0); - - for (auto *PBI : BranchInsts) { - assert(isa(PBI->getCondition()) && - "Unexpected condition in a conditional branch."); - ICmpInst *Cmp = cast(PBI->getCondition()); - Value *Arg = Cmp->getOperand(0); - assert(isa(Cmp->getOperand(1)) && - "Expected op1 to be a constant."); - Constant *ConstVal = cast(Cmp->getOperand(1)); - CmpInst::Predicate Pred = Cmp->getPredicate(); - - if (PBI->getParent() == HeaderBB) { - Instruction *&CallTakenFromHeader = - IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond; - Instruction *&CallUntakenFromHeader = - IsCSInTakenPath ? 
NewCSTakenFromNextCond : NewCSTakenFromHeader; - - assert((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) && - "Unexpected predicate in an OR condition"); - - // Set the constant value for agruments in the call predicated based on - // the OR condition. - Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ - ? CallTakenFromHeader - : CallUntakenFromHeader; - setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal); - - // Add the NonNull attribute if compared with the null pointer. - if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { - Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ - ? CallUntakenFromHeader - : CallTakenFromHeader; - addNonNullAttribute(Instr, CallToSetAttr, Arg); - } - continue; - } - - if (Pred == ICmpInst::ICMP_EQ) { - if (PBI->getSuccessor(0) == Instr->getParent()) { - // Set the constant value for the call taken from the second block in - // the OR condition. - setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal); - } else { - // Add the NonNull attribute if compared with the null pointer for the - // call taken from the second block in the OR condition. - if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) - addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg); - } - } else { - if (PBI->getSuccessor(0) == Instr->getParent()) { - // Add the NonNull attribute if compared with the null pointer for the - // call taken from the second block in the OR condition. - if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) - addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg); - } else if (Pred == ICmpInst::ICMP_NE) { - // Set the constant value for the call in the untaken path from the - // header block. 
- setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal); - } else - llvm_unreachable("Unexpected condition"); +Instruction *addConditions(CallSite &CS, SmallVectorImpl> &Conditions) { + if (Conditions.empty()) + return nullptr; + + Instruction *NewCI = CS.getInstruction()->clone(); + for (auto &Cond : Conditions) { + Value *Arg = Cond.first->getOperand(0); + Constant *ConstVal = cast(Cond.first->getOperand(1)); + if (Cond.second == ICmpInst::ICMP_EQ) + setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal); + else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { + assert(Cond.second == ICmpInst::ICMP_NE); + addNonNullAttribute(CS.getInstruction(), NewCI, Arg); } } - return NewCSTakenFromHeader || NewCSTakenFromNextCond; + return NewCI; +} + +static SmallVector getTwoPredecessors(BasicBlock *BB) { + SmallVector Preds(predecessors((BB))); + assert(Preds.size() == 2 && "Expected exactly 2 predecessors!"); + return Preds; } static bool canSplitCallSite(CallSite CS) { @@ -234,6 +189,7 @@ } /// Return true if the CS is split into its new predecessors which are directly + /// hooked to each of its original predecessors pointed by PredBB1 and PredBB2. /// In OR predicated case, PredBB1 will point the header, and PredBB2 will point /// to the second compare block. 
CallInst1 and CallInst2 will be the new @@ -357,12 +313,6 @@ return false; } -static SmallVector getTwoPredecessors(BasicBlock *BB) { - SmallVector Preds(predecessors((BB))); - assert(Preds.size() == 2 && "Expected exactly 2 predecessors!"); - return Preds; -} - static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { if (!isPredicatedOnPHI(CS)) return false; @@ -382,26 +332,19 @@ static bool tryToSplitOnOrPredicatedArgument(CallSite CS) { auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); - BasicBlock *HeaderBB = nullptr; - BasicBlock *OrBB = nullptr; - if (isOrHeader(Preds[0], Preds[1])) { - HeaderBB = Preds[0]; - OrBB = Preds[1]; - } else if (isOrHeader(Preds[1], Preds[0])) { - HeaderBB = Preds[1]; - OrBB = Preds[0]; - } else + if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0])) return false; - Instruction *CallInst1 = nullptr; - Instruction *CallInst2 = nullptr; - if (!tryCreateCallSitesOnOrPredicatedArgument(CS, CallInst1, CallInst2, - HeaderBB)) { - assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned."); + SmallVector, 2> C1, C2; + recordConditions(CS, Preds[0], C1); + recordConditions(CS, Preds[1], C2); + + Instruction *CallInst1 = addConditions(CS, C1); + Instruction *CallInst2 = addConditions(CS, C2); + if (!CallInst1 && !CallInst2) return false; - } - splitCallSite(CS, HeaderBB, OrBB, CallInst1, CallInst2); + splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1); return true; } Index: test/Transforms/CallSiteSplitting/callsite-split-or-phi.ll =================================================================== --- test/Transforms/CallSiteSplitting/callsite-split-or-phi.ll +++ test/Transforms/CallSiteSplitting/callsite-split-or-phi.ll @@ -342,6 +342,36 @@ ret i32 %v } + +;CHECK-LABEL: @test_eq_eq_eq_untaken +;CHECK-LABEL: Tail.predBB1.split: +;CHECK: %[[CALL1:.*]] = call i32 @callee(i32* nonnull %a, i32 %v, i32 10) +;CHECK-LABEL: Tail.predBB2.split: +;CHECK: %[[CALL2:.*]] = call i32 @callee(i32* 
nonnull %a, i32 1, i32 %p) +;CHECK-LABEL: Tail +;CHECK: %[[MERGED:.*]] = phi i32 [ %[[CALL1]], %Tail.predBB1.split ], [ %[[CALL2]], %Tail.predBB2.split ] +;CHECK: ret i32 %[[MERGED]] +define i32 @test_eq_eq_eq_untaken(i32* %a, i32 %v, i32 %p) { +Header: + %tobool1 = icmp eq i32* %a, null + br i1 %tobool1, label %End, label %Header2 + +Header2: + %tobool2 = icmp eq i32 %p, 10 + br i1 %tobool2, label %Tail, label %TBB + +TBB: + %cmp = icmp eq i32 %v, 1 + br i1 %cmp, label %Tail, label %End + +Tail: + %r = call i32 @callee(i32* %a, i32 %v, i32 %p) + ret i32 %r + +End: + ret i32 %v +} + define i32 @callee(i32* %a, i32 %v, i32 %p) { entry: %c = icmp ne i32* %a, null