Index: include/llvm/Transforms/Utils/PredicateInfo.h =================================================================== --- include/llvm/Transforms/Utils/PredicateInfo.h +++ include/llvm/Transforms/Utils/PredicateInfo.h @@ -92,7 +92,7 @@ class raw_ostream; class OrderedBasicBlock; -enum PredicateType { PT_Branch, PT_Assume }; +enum PredicateType { PT_Branch, PT_Assume, PT_Switch }; // Base class for all predicate information we provide. // All of our predicate information has at least a comparison. @@ -103,24 +103,35 @@ // This can be use by passes, when destroying predicateinfo, to know // whether they can just drop the intrinsic, or have to merge metadata. Value *OriginalOp; - CmpInst *Comparison; PredicateBase(const PredicateBase &) = delete; PredicateBase &operator=(const PredicateBase &) = delete; PredicateBase() = delete; protected: - PredicateBase(PredicateType PT, Value *Op, CmpInst *Comparison) - : Type(PT), OriginalOp(Op), Comparison(Comparison) {} + PredicateBase(PredicateType PT, Value *Op) : Type(PT), OriginalOp(Op) {} +}; + +class PredicateWithComparison : public PredicateBase { +public: + CmpInst *Comparison; + static inline bool classof(const PredicateBase *PB) { + return PB->Type == PT_Assume || PB->Type == PT_Branch; + } + +protected: + PredicateWithComparison(PredicateType PT, Value *Op, CmpInst *Comparison) + : PredicateBase(PT, Op), Comparison(Comparison) {} }; // Provides predicate information for assumes. Since assumes are always true, // we simply provide the assume instruction, so you can tell your relative // position to it. -class PredicateAssume : public PredicateBase { +class PredicateAssume : public PredicateWithComparison { public: IntrinsicInst *AssumeInst; PredicateAssume(Value *Op, IntrinsicInst *AssumeInst, CmpInst *Comparison) - : PredicateBase(PT_Assume, Op, Comparison), AssumeInst(AssumeInst) {} + : PredicateWithComparison(PT_Assume, Op, Comparison), + AssumeInst(AssumeInst) {} PredicateAssume() = delete; static inline bool classof(const PredicateBase *PB) { return PB->Type == PT_Assume; @@ -128,7 +139,7 @@ }; // Provides predicate information for branches. -class PredicateBranch : public PredicateBase { +class PredicateBranch : public PredicateWithComparison { public: // This is the block that is conditional upon the comparison. BasicBlock *BranchBB; @@ -138,7 +149,7 @@ bool TrueEdge; PredicateBranch(Value *Op, BasicBlock *BranchBB, BasicBlock *SplitBB, CmpInst *Comparison, bool TakenEdge) - : PredicateBase(PT_Branch, Op, Comparison), BranchBB(BranchBB), + : PredicateWithComparison(PT_Branch, Op, Comparison), BranchBB(BranchBB), SplitBB(SplitBB), TrueEdge(TakenEdge) {} PredicateBranch() = delete; static inline bool classof(const PredicateBase *PB) { @@ -146,6 +157,25 @@ } }; +class PredicateSwitch : public PredicateBase { +public: + // This is the block of the switch. + BasicBlock *SwitchBB; + // This is the target block of the case statement. + BasicBlock *TargetBB; + // This is the case value for the switch to that block. + Value *CaseValue; + // This is the switch instruction. + SwitchInst *Switch; + PredicateSwitch(Value *Op, BasicBlock *SwitchBB, BasicBlock *TargetBB, + Value *CaseValue, SwitchInst *SI) + : PredicateBase(PT_Switch, Op), SwitchBB(SwitchBB), TargetBB(TargetBB), + CaseValue(CaseValue), Switch(SI) {} + static inline bool classof(const PredicateBase *PB) { + return PB->Type == PT_Switch; + } +}; + // This name is used in a few places, so kick it into their own namespace namespace PredicateInfoClasses { struct ValueDFS; @@ -188,6 +218,7 @@ void buildPredicateInfo(); void processAssume(IntrinsicInst *, BasicBlock *, SmallPtrSetImpl &); void processBranch(BranchInst *, BasicBlock *, SmallPtrSetImpl &); + void processSwitch(SwitchInst *, BasicBlock *, SmallPtrSetImpl &); void renameUses(SmallPtrSetImpl &); using ValueDFS = PredicateInfoClasses::ValueDFS; typedef SmallVectorImpl ValueDFSStack; @@ -216,7 +247,7 @@ DenseMap> OBBMap; // The set of edges along which we can only handle phi uses, due to critical // edges. - DenseSet PhiUsesOnly; + DenseSet> PhiUsesOnly; }; // This pass does eager building and then printing of PredicateInfo. It is used Index: lib/Transforms/Scalar/NewGVN.cpp =================================================================== --- lib/Transforms/Scalar/NewGVN.cpp +++ lib/Transforms/Scalar/NewGVN.cpp @@ -839,51 +839,65 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) { if (auto *PI = PredInfo->getPredicateInfoFor(I)) { DEBUG(dbgs() << "Found predicate info from instruction !\n"); - auto *CopyOf = I->getOperand(0); - auto *Cmp = PI->Comparison; - // If this is an assume predicate and a copy of a comparison, it must be - // true. - if (isa(PI) && CopyOf == Cmp) - return createConstantExpression(ConstantInt::getTrue(Cmp->getType())); - - Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0)); - Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1)); - // Sort the ops - CmpInst::Predicate Predicate = Cmp->getPredicate(); - // FIXME: We should really be ranking them here - if (getRank(FirstOp) > getRank(SecondOp)) { - std::swap(FirstOp, SecondOp); - Predicate = CmpInst::getSwappedPredicate(Predicate); - } + if (auto *PWC = dyn_cast(PI)) { + auto *CopyOf = I->getOperand(0); + auto *Cmp = PWC->Comparison; + // If this is an assume predicate and a copy of a comparison, it must be + // true. + if (isa(PI) && CopyOf == Cmp) + return createConstantExpression(ConstantInt::getTrue(Cmp->getType())); + + Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0)); + Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1)); + // Sort the ops + CmpInst::Predicate Predicate = Cmp->getPredicate(); + // FIXME: We should really be ranking them here + if (getRank(FirstOp) > getRank(SecondOp)) { + std::swap(FirstOp, SecondOp); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } - if (isa(PI)) { - // If the comparison is true when the operands are equal, then we know the - // operands are equal, because assumes must always be true. - if (CmpInst::isTrueWhenEqual(Predicate)) - if (auto *C = dyn_cast(FirstOp)) - return createConstantExpression(C); - return createVariableExpression(FirstOp); - } else if (const auto *PBranch = dyn_cast(PI)) { - // If this is a copy of the comparison, it's value is determined by the - // edge. - if (CopyOf == Cmp) { - if (PBranch->TrueEdge) - return createConstantExpression(ConstantInt::getTrue(Cmp->getType())); - return createConstantExpression(ConstantInt::getFalse(Cmp->getType())); - } else if ((PBranch->TrueEdge && CmpInst::isTrueWhenEqual(Predicate)) || - (!PBranch->TrueEdge && CmpInst::isFalseWhenEqual(Predicate))) { - // If we are *not* a copy of the comparison, we may equal to the other - // operand when - // the predicate implies something about equality of operations. - if (auto *C = dyn_cast(FirstOp)) - return createConstantExpression(C); + if (isa(PI)) { + // If the comparison is true when the operands are equal, then we know + // the + // operands are equal, because assumes must always be true. + if (CmpInst::isTrueWhenEqual(Predicate)) + if (auto *C = dyn_cast(FirstOp)) + return createConstantExpression(C); return createVariableExpression(FirstOp); - } else if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) || - (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) && - isa(FirstOp) && - !cast(FirstOp)->isZero()) { - return createConstantExpression(cast(FirstOp)); + } else if (const auto *PBranch = dyn_cast(PI)) { + // If this is a copy of the comparison, it's value is determined by the + // edge. + if (CopyOf == Cmp) { + if (PBranch->TrueEdge) + return createConstantExpression( + ConstantInt::getTrue(Cmp->getType())); + return createConstantExpression( + ConstantInt::getFalse(Cmp->getType())); + } else if ((PBranch->TrueEdge && CmpInst::isTrueWhenEqual(Predicate)) || + (!PBranch->TrueEdge && + CmpInst::isFalseWhenEqual(Predicate))) { + // If we are *not* a copy of the comparison, we may equal to the other + // operand when + // the predicate implies something about equality of operations. + if (auto *C = dyn_cast(FirstOp)) + return createConstantExpression(C); + return createVariableExpression(FirstOp); + } else if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) || + (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) && + isa(FirstOp) && + !cast(FirstOp)->isZero()) { + return createConstantExpression(cast(FirstOp)); + } } + } else if (auto *PSwitch = dyn_cast(PI)) { + // For switch statements, we know the value of the condition, and that is + // what Op must be. + assert(I->getOperand(0) == PSwitch->Switch->getCondition() && + "Found a predicateswitch info that seems to be about a non-switch " + "condition"); + assert(isa(PSwitch->CaseValue)); + return createConstantExpression(cast(PSwitch->CaseValue)); } } return nullptr; Index: lib/Transforms/Utils/PredicateInfo.cpp =================================================================== --- lib/Transforms/Utils/PredicateInfo.cpp +++ lib/Transforms/Utils/PredicateInfo.cpp @@ -48,6 +48,41 @@ static cl::opt VerifyPredicateInfo( "verify-predicateinfo", cl::init(false), cl::Hidden, cl::desc("Verify PredicateInfo in legacy printer pass.")); +namespace { + +// Given a predicate info that is a type of branching terminator, get the +// branching block. +const BasicBlock *getBranchBlock(PredicateBase *PB) { + if (auto *PBranch = dyn_cast(PB)) + return PBranch->BranchBB; + if (auto *PSwitch = dyn_cast(PB)) + return PSwitch->SwitchBB; + llvm_unreachable("Only branches and switches should have PHIOnly defs that " + "require branch blocks"); +} + +// Given a predicate info that is a type of branching terminator, get the +// branching terminator. +static Instruction *getBranchTerminator(PredicateBase *PB) { + if (auto *PBranch = dyn_cast(PB)) + return PBranch->BranchBB->getTerminator(); + if (auto *PSwitch = dyn_cast(PB)) + return PSwitch->SwitchBB->getTerminator(); + llvm_unreachable("Not a predicate info type we know how to handle"); +} + +// Given a predicate info that is a type of branching terminator, get the +// edge this predicate info represents +const std::pair +getBlockEdge(const PredicateBase *PB) { + if (auto *PBranch = dyn_cast(PB)) + return std::make_pair(PBranch->BranchBB, PBranch->SplitBB); + if (auto *PSwitch = dyn_cast(PB)) + return std::make_pair(PSwitch->SwitchBB, PSwitch->TargetBB); + llvm_unreachable("Unhandle predicate info type"); +} +} + namespace llvm { namespace PredicateInfoClasses { enum LocalNum { @@ -106,15 +141,14 @@ } // For a phi use, or a non-materialized def, return the edge it represents. - const std::pair + const std::pair getBlockEdge(const ValueDFS &VD) const { if (!VD.Def && VD.Use) { auto *PHI = cast(VD.Use->getUser()); return std::make_pair(PHI->getIncomingBlock(*VD.Use), PHI->getParent()); } // This is really a non-materialized def. - auto *PBranch = cast(VD.PInfo); - return std::make_pair(PBranch->BranchBB, PBranch->SplitBB); + return ::getBlockEdge(VD.PInfo); } // For two phi related values, return the ordering. @@ -208,12 +242,10 @@ auto *PHI = dyn_cast(VDUse.Use->getUser()); if (!PHI) return false; - // The only phionly defs should be branch info. - auto *PBranch = dyn_cast(Stack.back().PInfo); - assert(PBranch && "Only branches should have PHIOnly defs"); - // Check edge - BasicBlock *EdgePred = PHI->getIncomingBlock(*VDUse.Use); - if (EdgePred != PBranch->BranchBB) + // The only phionly defs should be branch or switch info. + // Check that this is our edge. + const BasicBlock *EdgePred = PHI->getIncomingBlock(*VDUse.Use); + if (EdgePred != getBranchBlock(Stack.back().PInfo)) return false; } @@ -374,6 +406,36 @@ } } } +// Process a block terminating switch, and place relevant operations to be +// renamed into OpsToRename. +void PredicateInfo::processSwitch(SwitchInst *SI, BasicBlock *BranchBB, + SmallPtrSetImpl &OpsToRename) { + Value *Op = SI->getCondition(); + if ((!isa(Op) && !isa(Op)) || Op->hasOneUse()) + return; + + // Remember how many outgoing edges there are to every successor. + SmallDenseMap SwitchEdges; + for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = SI->getSuccessor(i); + ++SwitchEdges[TargetBlock]; + } + + // Now propagate info for each case value + for (auto i = SI->case_begin(), e = SI->case_end(); i != e; ++i) { + BasicBlock *TargetBlock = i.getCaseSuccessor(); + if (SwitchEdges.lookup(TargetBlock) == 1) { + OpsToRename.insert(Op); + auto &OperandInfo = getOrCreateValueInfo(Op); + PredicateSwitch *PS = new PredicateSwitch( + Op, SI->getParent(), TargetBlock, i.getCaseValue(), SI); + AllInfos.push_back(PS); + OperandInfo.Infos.push_back(PS); + if (!TargetBlock->getSinglePredecessor()) + PhiUsesOnly.insert({BranchBB, TargetBlock}); + } + } +} // Build predicate info for our function void PredicateInfo::buildPredicateInfo() { @@ -387,6 +449,8 @@ if (!BI->isConditional()) continue; processBranch(BI, BranchBB, OpsToRename); + } else if (auto *SI = dyn_cast(BranchBB->getTerminator())) { + processSwitch(SI, BranchBB, OpsToRename); } } for (auto &Assume : AC.assumptions()) { @@ -396,6 +460,9 @@ // Now rename all our operations. renameUses(OpsToRename); } + +// Given the renaming stack, make all the operands currently on the stack real +// by inserting them into the IR. Return the last operation's value. Value *PredicateInfo::materializeStack(unsigned int &Counter, ValueDFSStack &RenameStack, Value *OrigOp) { @@ -420,9 +487,8 @@ // to ensure we dominate all of our uses. Always insert right before the // relevant instruction (terminator, assume), so that we insert in proper // order in the case of multiple predicateinfo in the same block. - if (isa(ValInfo)) { - auto *PBranch = cast(ValInfo); - IRBuilder<> B(PBranch->BranchBB->getTerminator()); + if (isa(ValInfo) || isa(ValInfo)) { + IRBuilder<> B(getBranchTerminator(ValInfo)); Function *IF = Intrinsic::getDeclaration( F.getParent(), Intrinsic::ssa_copy, Op->getType()); Value *PIC = B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++)); @@ -488,14 +554,15 @@ VD.DFSOut = DomNode->getDFSNumOut(); VD.PInfo = PossibleCopy; OrderedUses.push_back(VD); - } else if (const auto *PBranch = - dyn_cast(PossibleCopy)) { + } else if (isa(PossibleCopy) || + isa(PossibleCopy)) { // If we can only do phi uses, we treat it like it's in the branch // block, and handle it specially. We know that it goes last, and only // dominate phi uses. - if (PhiUsesOnly.count({PBranch->BranchBB, PBranch->SplitBB})) { + auto BlockEdge = getBlockEdge(PossibleCopy); + if (PhiUsesOnly.count(BlockEdge)) { VD.LocalNum = LN_Last; - auto *DomNode = DT.getNode(PBranch->BranchBB); + auto *DomNode = DT.getNode(BlockEdge.first); if (DomNode) { VD.DFSIn = DomNode->getDFSNumIn(); VD.DFSOut = DomNode->getDFSNumOut(); @@ -508,7 +575,7 @@ // insertion in the branch block). // Insert a possible copy at the split block and before the branch. VD.LocalNum = LN_First; - auto *DomNode = DT.getNode(PBranch->SplitBB); + auto *DomNode = DT.getNode(BlockEdge.second); if (DomNode) { VD.DFSIn = DomNode->getDFSNumIn(); VD.DFSOut = DomNode->getDFSNumOut();