diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -88,6 +88,11 @@ /// Match an arbitrary Constant and ignore it. inline class_match m_Constant() { return class_match(); } +/// Match an arbitrary basic block value and ignore it. +inline class_match m_BasicBlock() { + return class_match(); +} + /// Inverting matcher template struct match_unless { Ty M; @@ -563,6 +568,12 @@ /// Match a ConstantFP, capturing the value if we match. inline bind_ty m_ConstantFP(ConstantFP *&C) { return C; } +/// Match a basic block value, capturing it if we match. +inline bind_ty m_BasicBlock(BasicBlock *&V) { return V; } +inline bind_ty m_BasicBlock(const BasicBlock *&V) { + return V; +} + /// Match a specified Value*. struct specificval_ty { const Value *Val; @@ -656,6 +667,24 @@ /// ConstantInts wider than 64-bits. inline bind_const_intval_ty m_ConstantInt(uint64_t &V) { return V; } +/// Match a specified basic block value. +struct specific_bbval { + BasicBlock *Val; + + specific_bbval(BasicBlock *Val) : Val(Val) {} + + template bool match(ITy *V) { + const auto *BB = dyn_cast(V); + return BB && BB == Val; + } +}; + +/// Match a specific basic block value: succeeds only when the matched value +/// is the block \p BB itself. +inline specific_bbval m_SpecificBB(BasicBlock *BB) { + return specific_bbval(BB); +} + //===----------------------------------------------------------------------===// // Matcher for any binary operator.
// @@ -1334,27 +1363,34 @@ inline br_match m_UnconditionalBr(BasicBlock *&Succ) { return br_match(Succ); } -template struct brc_match { +template +struct brc_match { Cond_t Cond; - BasicBlock *&T, *&F; + TrueBlock_t T; + FalseBlock_t F; - brc_match(const Cond_t &C, BasicBlock *&t, BasicBlock *&f) + brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) : Cond(C), T(t), F(f) {} template bool match(OpTy *V) { if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) { - T = BI->getSuccessor(0); - F = BI->getSuccessor(1); - return true; - } + if (BI->isConditional() && Cond.match(BI->getCondition())) + return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); return false; } }; template -inline brc_match m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F) { - return brc_match(C, T, F); +inline brc_match, bind_ty> +m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F) { + return brc_match, bind_ty>( + C, m_BasicBlock(T), m_BasicBlock(F)); +} + +template +inline brc_match +m_Br(const Cond_t &C, const TrueBlock_t &T, const FalseBlock_t &F) { + return brc_match(C, T, F); } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp --- a/llvm/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -2105,11 +2105,10 @@ for (BasicBlock *BB : ExitingBlocks) { ICmpInst::Predicate Pred; Instruction *LHS, *RHS; - BasicBlock *TrueBB, *FalseBB; if (!match(BB->getTerminator(), m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), - TrueBB, FalseBB))) + m_BasicBlock(), m_BasicBlock()))) continue; if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp --- 
a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -121,14 +121,13 @@ BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1); BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1); Instruction *TermI = GuardBB->getTerminator(); - BasicBlock *TrueBB, *FalseBB; ICmpInst::Predicate Pred; - if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), TrueBB, - FalseBB))) + BasicBlock *PhiBB = Phi.getParent(); + if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), + m_SpecificBB(PhiBB), m_SpecificBB(RotBB)))) return false; - BasicBlock *PhiBB = Phi.getParent(); - if (Pred != CmpInst::ICMP_EQ || TrueBB != PhiBB || FalseBB != RotBB) + if (Pred != CmpInst::ICMP_EQ) return false; // We matched a variation of this IR pattern: diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2557,9 +2557,7 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // Change br (not X), label True, label False to: br X, label False, True Value *X = nullptr; - BasicBlock *TrueDest; - BasicBlock *FalseDest; - if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) && + if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && !isa(X)) { // Swap Destinations and condition... BI.setCondition(X); @@ -2577,8 +2575,8 @@ // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest, - FalseDest)) && + if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), + m_BasicBlock(), m_BasicBlock())) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. 
CmpInst *Cond = cast(BI.getCondition()); diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -1045,6 +1045,28 @@ EXPECT_FALSE(match(V3, m_FNeg(m_Value(Match)))); } +TEST_F(PatternMatchTest, CondBranchTest) { + BasicBlock *TrueBB = BasicBlock::Create(Ctx, "TrueBB", F); + BasicBlock *FalseBB = BasicBlock::Create(Ctx, "FalseBB", F); + Value *Br1 = IRB.CreateCondBr(IRB.getTrue(), TrueBB, FalseBB); + + EXPECT_TRUE(match(Br1, m_Br(m_Value(), m_BasicBlock(), m_BasicBlock()))); + + BasicBlock *A, *B; + EXPECT_TRUE(match(Br1, m_Br(m_Value(), m_BasicBlock(A), m_BasicBlock(B)))); + EXPECT_EQ(TrueBB, A); + EXPECT_EQ(FalseBB, B); + + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_SpecificBB(FalseBB), m_BasicBlock()))); + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_BasicBlock(), m_SpecificBB(TrueBB)))); + EXPECT_FALSE( + match(Br1, m_Br(m_Value(), m_SpecificBB(FalseBB), m_BasicBlock(TrueBB)))); + EXPECT_TRUE( + match(Br1, m_Br(m_Value(), m_SpecificBB(TrueBB), m_BasicBlock(FalseBB)))); +} + template struct MutableConstTest : PatternMatchTest { }; typedef ::testing::Types,