Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -1308,6 +1308,13 @@
   static bool classof(const Value *V) {
     return isa<Instruction>(V) && classof(cast<Instruction>(V));
   }
+
+  /// Return the predicate that holds exactly when both \p Pred1 and \p Pred2
+  /// hold for the same pair of operands, i.e. the intersection of the two
+  /// predicates. Return BAD_ICMP_PREDICATE when the intersection is empty or
+  /// is not known to this function (e.g. when a signed and an unsigned
+  /// ordering are mixed).
+  static Predicate getAndPredicate(Predicate Pred1, Predicate Pred2);
 };
 
 //===----------------------------------------------------------------------===//
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -3398,6 +3398,86 @@
   }
 }
 
+ICmpInst::Predicate ICmpInst::getAndPredicate(Predicate Pred1,
+                                              Predicate Pred2) {
+  // Identical predicates intersect to themselves.
+  if (Pred1 == Pred2)
+    return Pred1;
+  switch (Pred1) {
+  default:
+    break;
+  case ICMP_EQ:
+    // Equality is contained in every non-strict ordering.
+    if (Pred2 == ICMP_UGE || Pred2 == ICMP_ULE || Pred2 == ICMP_SGE ||
+        Pred2 == ICMP_SLE)
+      return ICMP_EQ;
+    break;
+  case ICMP_NE:
+    switch (Pred2) {
+    default:
+      break;
+    // A strict ordering already excludes equality.
+    case ICMP_UGT:
+    case ICMP_ULT:
+    case ICMP_SGT:
+    case ICMP_SLT:
+      return Pred2;
+    // Dropping equality from a non-strict ordering makes it strict.
+    case ICMP_UGE:
+      return ICMP_UGT;
+    case ICMP_ULE:
+      return ICMP_ULT;
+    case ICMP_SGE:
+      return ICMP_SGT;
+    case ICMP_SLE:
+      return ICMP_SLT;
+    }
+    break;
+  case ICMP_UGT:
+    if (Pred2 == ICMP_UGE || Pred2 == ICMP_NE)
+      return ICMP_UGT;
+    break;
+  case ICMP_UGE:
+    if (Pred2 == ICMP_UGT || Pred2 == ICMP_NE)
+      return ICMP_UGT;
+    if (Pred2 == ICMP_EQ || Pred2 == ICMP_ULE)
+      return ICMP_EQ;
+    break;
+  case ICMP_ULT:
+    if (Pred2 == ICMP_ULE || Pred2 == ICMP_NE)
+      return ICMP_ULT;
+    break;
+  case ICMP_ULE:
+    if (Pred2 == ICMP_ULT || Pred2 == ICMP_NE)
+      return ICMP_ULT;
+    if (Pred2 == ICMP_EQ || Pred2 == ICMP_UGE)
+      return ICMP_EQ;
+    break;
+  case ICMP_SGT:
+    if (Pred2 == ICMP_SGE || Pred2 == ICMP_NE)
+      return ICMP_SGT;
+    break;
+  case ICMP_SGE:
+    if (Pred2 == ICMP_SGT || Pred2 == ICMP_NE)
+      return ICMP_SGT;
+    if (Pred2 == ICMP_EQ || Pred2 == ICMP_SLE)
+      return ICMP_EQ;
+    break;
+  case ICMP_SLT:
+    if (Pred2 == ICMP_SLE || Pred2 == ICMP_NE)
+      return ICMP_SLT;
+    break;
+  case ICMP_SLE:
+    if (Pred2 == ICMP_SLT || Pred2 == ICMP_NE)
+      return ICMP_SLT;
+    if (Pred2 == ICMP_EQ || Pred2 == ICMP_SGE)
+      return ICMP_EQ;
+    break;
+  }
+  // Empty intersection, mixed signedness, or a non-integer predicate.
+  return BAD_ICMP_PREDICATE;
+}
+
 CmpInst::Predicate CmpInst::getFlippedStrictnessPredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown or unsupported cmp predicate!");
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -194,6 +194,9 @@
   bool SimplifyIndirectBr(IndirectBrInst *IBI);
   bool SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder);
   bool SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder);
+  Value *GetCombinedCompareInstruction(IRBuilder<> &Builder, const Value *CmpA,
+                                       const Value *CmpB, bool CmpAIsTrue,
+                                       bool CmpBIsTrue);
 
   bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI,
                                              IRBuilder<> &Builder);
@@ -5821,6 +5824,69 @@
   return PredPred;
 }
 
+/// Return true if the operands of the two compares match. IsSwappedOps is true
+/// when the operands match, but are swapped.
+static bool isMatchingOps(const Value *ALHS, const Value *ARHS,
+                          const Value *BLHS, const Value *BRHS,
+                          bool &IsSwappedOps) {
+  bool IsMatchingOps = (ALHS == BLHS && ARHS == BRHS);
+  IsSwappedOps = (ALHS == BRHS && ARHS == BLHS);
+  return IsMatchingOps || IsSwappedOps;
+}
+
+/// Combine two integer compares over the same pair of operands, whose outcomes
+/// are known to be \p CmpAIsTrue and \p CmpBIsTrue, into one new compare. The
+/// new compare evaluates to \p CmpBIsTrue exactly when both known outcomes
+/// hold. Return nullptr if the two compares cannot be combined this way. The
+/// new compare is created via \p Builder and is left for the caller to erase.
+Value *SimplifyCFGOpt::GetCombinedCompareInstruction(IRBuilder<> &Builder,
+                                                     const Value *CmpA,
+                                                     const Value *CmpB,
+                                                     bool CmpAIsTrue,
+                                                     bool CmpBIsTrue) {
+  Type *OpTy = CmpA->getType();
+  // Identical conditions add no information; mismatched types cannot combine.
+  if (OpTy != CmpB->getType() || CmpA == CmpB)
+    return nullptr;
+
+  assert(OpTy->isIntegerTy(1) && "Expected i1 type only!");
+
+  const ICmpInst *CmpAInst = dyn_cast<ICmpInst>(CmpA);
+  const ICmpInst *CmpBInst = dyn_cast<ICmpInst>(CmpB);
+  if (!CmpAInst || !CmpBInst)
+    return nullptr;
+
+  // The rest of the logic assumes that both conditions are true. Where that
+  // is not the case, use the inverse predicate to make it so.
+  Value *ALHS = CmpAInst->getOperand(0);
+  Value *ARHS = CmpAInst->getOperand(1);
+  ICmpInst::Predicate APred =
+      CmpAIsTrue ? CmpAInst->getPredicate() : CmpAInst->getInversePredicate();
+
+  Value *BLHS = CmpBInst->getOperand(0);
+  Value *BRHS = CmpBInst->getOperand(1);
+  ICmpInst::Predicate BPred =
+      CmpBIsTrue ? CmpBInst->getPredicate() : CmpBInst->getInversePredicate();
+
+  // Can we infer anything when the two compares have matching operands?
+  bool IsSwappedOpsB;
+  if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, IsSwappedOpsB)) {
+    if (IsSwappedOpsB)
+      BPred = ICmpInst::getSwappedPredicate(BPred);
+
+    ICmpInst::Predicate AndPred = ICmpInst::getAndPredicate(APred, BPred);
+    if (AndPred != CmpInst::BAD_ICMP_PREDICATE) {
+      // The caller evaluates the new compare with CmpB's known outcome, so
+      // store the inverse predicate when that outcome is false.
+      if (!CmpBIsTrue)
+        AndPred = ICmpInst::getInversePredicate(AndPred);
+      return Builder.CreateICmp(AndPred, ALHS, ARHS);
+    }
+  }
+
+  return nullptr;
+}
+
 bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
   BasicBlock *BB = BI->getParent();
   const Function *Fn = BB->getParent();
@@ -5861,9 +5927,34 @@
     if (PBI && PBI->isConditional() &&
         PBI->getSuccessor(0) != PBI->getSuccessor(1)) {
       assert(PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB);
       bool CondIsTrue = PBI->getSuccessor(0) == BB;
       Optional<bool> Implication = isImpliedCondition(
           PBI->getCondition(), BI->getCondition(), DL, CondIsTrue);
+      // If PBI's condition alone does not decide BI, look one block further
+      // up. If Dom itself has a single predecessor that ends in a conditional
+      // branch, the outcome of that branch is known here as well, and both
+      // conditions together may decide BI. E.g. a >= b and a != b give a > b.
+      if (!Implication) {
+        if (BasicBlock *PDom = Dom->getSinglePredecessor()) {
+          auto *PPBI = dyn_cast_or_null<BranchInst>(PDom->getTerminator());
+          if (PPBI && PPBI->isConditional() &&
+              PPBI->getSuccessor(0) != PPBI->getSuccessor(1)) {
+            bool PCondIsTrue = PPBI->getSuccessor(0) == Dom;
+            // Build a temporary compare that stands for both conditions at
+            // once, so that it can be passed to isImpliedCondition.
+            Value *CombinedAndCmp = GetCombinedCompareInstruction(
+                Builder, PPBI->getCondition(), PBI->getCondition(),
+                PCondIsTrue, CondIsTrue);
+            if (CombinedAndCmp) {
+              Implication = isImpliedCondition(
+                  CombinedAndCmp, BI->getCondition(), DL, CondIsTrue);
+              // The temporary compare only serves the query above; erase
+              // it again so that it does not stay behind in the function.
+              RecursivelyDeleteTriviallyDeadInstructions(CombinedAndCmp);
+            }
+          }
+        }
+      }
       if (Implication) {
         // Turn this into a branch on constant.
         auto *OldCond = BI->getCondition();
Index: llvm/test/Transforms/SimplifyCFG/branch-fold-three.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimplifyCFG/branch-fold-three.ll
@@ -0,0 +1,34 @@
+; RUN: opt %s -S -simplifycfg | FileCheck %s
+
+; The path into %if.then2 establishes !(s < t) and !(s == t), i.e. s > t, so
+; the branch on %cmp3 always goes to %if.end6 and %if.then4 is dead.
+; CHECK-LABEL: @_Z3fooii(
+; CHECK-NOT: call void @_Z4bar3ii
+define void @_Z3fooii(i32 signext %s, i32 signext %t) local_unnamed_addr #0 {
+entry:
+  %cmp = icmp slt i32 %s, %t
+  br i1 %cmp, label %if.end6, label %if.then
+
+if.then:                                          ; preds = %entry
+  tail call void @_Z4bar1ii(i32 signext %s, i32 signext %t)
+  %cmp1 = icmp eq i32 %s, %t
+  br i1 %cmp1, label %if.end6, label %if.then2
+
+if.then2:                                         ; preds = %if.then
+  tail call void @_Z4bar2ii(i32 signext %s, i32 signext %t)
+  %cmp3 = icmp sgt i32 %s, %t
+  br i1 %cmp3, label %if.end6, label %if.then4
+
+if.then4:                                         ; preds = %if.then2
+  tail call void @_Z4bar3ii(i32 signext %s, i32 signext %t)
+  br label %if.end6
+
+if.end6:                                          ; preds = %if.then2, %if.then, %entry, %if.then4
+  ret void
+}
+
+declare void @_Z4bar1ii(i32 signext, i32 signext) local_unnamed_addr #1
+
+declare void @_Z4bar2ii(i32 signext, i32 signext) local_unnamed_addr #1
+
+declare void @_Z4bar3ii(i32 signext, i32 signext) local_unnamed_addr #1