Index: include/llvm/Transforms/Scalar/ConstantHoisting.h =================================================================== --- include/llvm/Transforms/Scalar/ConstantHoisting.h +++ include/llvm/Transforms/Scalar/ConstantHoisting.h @@ -36,6 +36,7 @@ #ifndef LLVM_TRANSFORMS_SCALAR_CONSTANTHOISTING_H #define LLVM_TRANSFORMS_SCALAR_CONSTANTHOISTING_H +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/PassManager.h" @@ -98,7 +99,7 @@ // Glue for old PM. bool runImpl(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, - BasicBlock &Entry); + BlockFrequencyInfo *BFI, BasicBlock &Entry); void releaseMemory() { ConstantVec.clear(); @@ -112,6 +113,7 @@ const TargetTransformInfo *TTI; DominatorTree *DT; + BlockFrequencyInfo *BFI; BasicBlock *Entry; /// Keeps track of constant candidates found in the function. @@ -124,8 +126,8 @@ SmallVector ConstantVec; Instruction *findMatInsertPt(Instruction *Inst, unsigned Idx = ~0U) const; - Instruction *findConstantInsertionPoint( - const consthoist::ConstantInfo &ConstInfo) const; + SmallPtrSet + findConstantInsertionPoint(const consthoist::ConstantInfo &ConstInfo) const; void collectConstantCandidates(ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, ConstantInt *ConstInt); Index: lib/Transforms/Scalar/ConstantHoisting.cpp =================================================================== --- lib/Transforms/Scalar/ConstantHoisting.cpp +++ lib/Transforms/Scalar/ConstantHoisting.cpp @@ -53,6 +53,12 @@ STATISTIC(NumConstantsHoisted, "Number of constants hoisted"); STATISTIC(NumConstantsRebased, "Number of constants rebased"); +static cl::opt ConstHoistWithBlockFrequency( + "consthoist-with-block-frequency", cl::init(true), cl::Hidden, + cl::desc("Enable the use of the block frequency analysis to reduce the " + "chance to execute const materialization more frequently than " + "without hoisting.")); + namespace { /// \brief 
The constant hoisting pass. class ConstantHoistingLegacyPass : public FunctionPass { @@ -68,6 +74,8 @@ void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + if (ConstHoistWithBlockFrequency) + AU.addRequired(); AU.addRequired(); AU.addRequired(); } @@ -82,6 +90,7 @@ char ConstantHoistingLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", "Constant Hoisting", false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist", @@ -99,9 +108,13 @@ DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n"); DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); - bool MadeChange = Impl.runImpl( - Fn, getAnalysis().getTTI(Fn), - getAnalysis().getDomTree(), Fn.getEntryBlock()); + bool MadeChange = + Impl.runImpl(Fn, getAnalysis().getTTI(Fn), + getAnalysis().getDomTree(), + ConstHoistWithBlockFrequency + ? &getAnalysis().getBFI() + : nullptr, + Fn.getEntryBlock()); if (MadeChange) { DEBUG(dbgs() << "********** Function after Constant Hoisting: " @@ -140,33 +153,140 @@ return IDom->getTerminator(); } +/// \brief Given \p BBs as input, find another set of BBs which collectively +/// dominates \p BBs and has the minimal sum of frequencies. Return the BB +/// set found in \p BBs. +static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, + BasicBlock *Entry, + SmallPtrSet &BBs) { + assert(!BBs.count(Entry) && "Assume Entry is not in BBs"); + // Nodes on the current path to the root. + SmallPtrSet Path; + // Candidates includes any block 'BB' in set 'BBs' that is not strictly + // dominated by any other blocks in set 'BBs', and all nodes in the path + // in the dominator tree from Entry to 'BB'. 
+ SmallPtrSet Candidates; + for (auto BB : BBs) { + Path.clear(); + // Walk up the dominator tree until Entry or another BB in BBs + // is reached. Insert the nodes on the way to the Path. + BasicBlock *Node = BB; + // The "Path" is a candidate path to be added into Candidates set. + bool isCandidate = false; + do { + Path.insert(Node); + if (Node == Entry || Candidates.count(Node)) { + isCandidate = true; + break; + } + assert(DT.getNode(Node)->getIDom() && + "Entry doens't dominate current Node"); + Node = DT.getNode(Node)->getIDom()->getBlock(); + } while (!BBs.count(Node)); + + // If isCandidate is false, Node is another Block in BBs dominating + // current 'BB'. Drop the nodes on the Path. + if (!isCandidate) + continue; + + // Add nodes on the Path into Candidates. + Candidates.insert(Path.begin(), Path.end()); + } + + // Sort the nodes in Candidates in top-down order and save the nodes + // in Orders. + unsigned Idx = 0; + SmallVector Orders; + Orders.push_back(Entry); + while (Idx != Orders.size()) { + BasicBlock *Node = Orders[Idx++]; + for (auto ChildDomNode : DT.getNode(Node)->getChildren()) { + if (Candidates.count(ChildDomNode->getBlock())) + Orders.push_back(ChildDomNode->getBlock()); + } + } + + // Visit Orders in bottom-up order. + typedef std::pair, BlockFrequency> + InsertPtsCostPair; + // InsertPtsMap is a map from a BB to the best insertion points for the + // subtree of BB (subtree not including the BB itself). + DenseMap InsertPtsMap; + InsertPtsMap.reserve(Orders.size() + 1); + for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) { + BasicBlock *Node = *RIt; + bool NodeInBBs = BBs.count(Node); + SmallPtrSet &InsertPts = InsertPtsMap[Node].first; + BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second; + + // Return the optimal insert points in BBs. 
+ if (Node == Entry) { + BBs.clear(); + if (InsertPtsFreq > BFI.getBlockFreq(Node)) + BBs.insert(Entry); + else + BBs.insert(InsertPts.begin(), InsertPts.end()); + break; + } + + BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock(); + // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child + // will update its parent's ParentInsertPts and ParentPtsFreq. + SmallPtrSet &ParentInsertPts = InsertPtsMap[Parent].first; + BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second; + // Choose to insert in Node or in subtree of Node. + if (InsertPtsFreq > BFI.getBlockFreq(Node) || NodeInBBs) { + ParentInsertPts.insert(Node); + ParentPtsFreq += BFI.getBlockFreq(Node); + } else { + ParentInsertPts.insert(InsertPts.begin(), InsertPts.end()); + ParentPtsFreq += InsertPtsFreq; + } + } +} + /// \brief Find an insertion point that dominates all uses. -Instruction *ConstantHoistingPass::findConstantInsertionPoint( +SmallPtrSet ConstantHoistingPass::findConstantInsertionPoint( const ConstantInfo &ConstInfo) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. 
SmallPtrSet BBs; + SmallPtrSet InsertPts; for (auto const &RCI : ConstInfo.RebasedConstants) for (auto const &U : RCI.Uses) BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); - if (BBs.count(Entry)) - return &Entry->front(); + if (BBs.count(Entry)) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } + + if (BFI) { + findBestInsertionSet(*DT, *BFI, Entry, BBs); + for (auto BB : BBs) { + Instruction &FirstInst = BB->front(); + InsertPts.insert(findMatInsertPt(&FirstInst)); + } + return InsertPts; + } while (BBs.size() >= 2) { BasicBlock *BB, *BB1, *BB2; BB1 = *BBs.begin(); BB2 = *std::next(BBs.begin()); BB = DT->findNearestCommonDominator(BB1, BB2); - if (BB == Entry) - return &Entry->front(); + if (BB == Entry) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } BBs.erase(BB1); BBs.erase(BB2); BBs.insert(BB); } assert((BBs.size() == 1) && "Expected only one element."); Instruction &FirstInst = (*BBs.begin())->front(); - return findMatInsertPt(&FirstInst); + InsertPts.insert(findMatInsertPt(&FirstInst)); + return InsertPts; } @@ -549,29 +669,54 @@ bool MadeChange = false; for (auto const &ConstInfo : ConstantVec) { // Hoist and hide the base constant behind a bitcast. - Instruction *IP = findConstantInsertionPoint(ConstInfo); - IntegerType *Ty = ConstInfo.BaseConstant->getType(); - Instruction *Base = - new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); - DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant << ") to BB " - << IP->getParent()->getName() << '\n' << *Base << '\n'); - NumConstantsHoisted++; + SmallPtrSet IPSet = findConstantInsertionPoint(ConstInfo); + assert(!IPSet.empty() && "IPSet is empty"); - // Emit materialization code for all rebased constants. 
- for (auto const &RCI : ConstInfo.RebasedConstants) { - NumConstantsRebased++; - for (auto const &U : RCI.Uses) - emitBaseConstants(Base, RCI.Offset, U); - } + unsigned UsesNum = 0; + unsigned ReBasesNum = 0; + for (Instruction *IP : IPSet) { + IntegerType *Ty = ConstInfo.BaseConstant->getType(); + Instruction *Base = + new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); + DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant + << ") to BB " << IP->getParent()->getName() << '\n' + << *Base << '\n'); + + // Emit materialization code for all rebased constants. + unsigned Uses = 0; + for (auto const &RCI : ConstInfo.RebasedConstants) { + for (auto const &U : RCI.Uses) { + Uses++; + BasicBlock *OrigMatInsertBB = + findMatInsertPt(U.Inst, U.OpndIdx)->getParent(); + // If Base constant is to be inserted in multiple places, + // generate rebase for U using the Base dominating U. + if (IPSet.size() == 1 || + DT->dominates(Base->getParent(), OrigMatInsertBB)) { + emitBaseConstants(Base, RCI.Offset, U); + ReBasesNum++; + } + } + } + UsesNum = Uses; + + // Use the same debug location as the last user of the constant. + assert(!Base->use_empty() && "The use list is empty!?"); + assert(isa(Base->user_back()) && + "All uses should be instructions."); + Base->setDebugLoc(cast(Base->user_back())->getDebugLoc()); + } + (void)UsesNum; + (void)ReBasesNum; + // Expect all uses are rebased after rebase is done. + assert(UsesNum == ReBasesNum && "Not all uses are rebased"); + + NumConstantsHoisted++; - // Use the same debug location as the last user of the constant. - assert(!Base->use_empty() && "The use list is empty!?"); - assert(isa(Base->user_back()) && - "All uses should be instructions."); - Base->setDebugLoc(cast(Base->user_back())->getDebugLoc()); + // Base constant is also included in ConstInfo.RebasedConstants, so + // deduct 1 from ConstInfo.RebasedConstants.size(). 
+ NumConstantsRebased += ConstInfo.RebasedConstants.size() - 1; - // Correct for base constant, which we counted above too. - NumConstantsRebased--; MadeChange = true; } return MadeChange; @@ -587,9 +732,11 @@ /// \brief Optimize expensive integer constants in the given function. bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, - DominatorTree &DT, BasicBlock &Entry) { + DominatorTree &DT, BlockFrequencyInfo *BFI, + BasicBlock &Entry) { this->TTI = &TTI; this->DT = &DT; + this->BFI = BFI; this->Entry = &Entry; // Collect all constant candidates. collectConstantCandidates(Fn); @@ -620,7 +767,10 @@ FunctionAnalysisManager &AM) { auto &DT = AM.getResult(F); auto &TTI = AM.getResult(F); - if (!runImpl(F, TTI, DT, F.getEntryBlock())) + auto BFI = ConstHoistWithBlockFrequency + ? &AM.getResult(F) + : nullptr; + if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock())) return PreservedAnalyses::all(); PreservedAnalyses PA; Index: test/CodeGen/X86/constant-hoisting-bfi.ll =================================================================== --- test/CodeGen/X86/constant-hoisting-bfi.ll +++ test/CodeGen/X86/constant-hoisting-bfi.ll @@ -0,0 +1,115 @@ +; RUN: opt -consthoist -mtriple=x86_64-unknown-linux-gnu -consthoist-with-block-frequency=true -S < %s | FileCheck %s + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" + +; Check when BFI is enabled for constant hoisting, constant 214748364701 +; will not be hoisted to the func entry. 
+; CHECK-LABEL: @foo( +; CHECK: entry: +; CHECK-NOT: bitcast i64 214748364701 to i64 +; CHECK: if.then: + +; Function Attrs: norecurse nounwind uwtable +define i64 @foo(i64* nocapture %a) { +entry: + %arrayidx = getelementptr inbounds i64, i64* %a, i64 9 + %t0 = load i64, i64* %arrayidx, align 8 + %cmp = icmp slt i64 %t0, 564 + br i1 %cmp, label %if.then, label %if.else5 + +if.then: ; preds = %entry + %arrayidx1 = getelementptr inbounds i64, i64* %a, i64 5 + %t1 = load i64, i64* %arrayidx1, align 8 + %cmp2 = icmp slt i64 %t1, 1009 + br i1 %cmp2, label %if.then3, label %return + +if.then3: ; preds = %if.then + %arrayidx4 = getelementptr inbounds i64, i64* %a, i64 6 + %t2 = load i64, i64* %arrayidx4, align 8 + %inc = add nsw i64 %t2, 1 + store i64 %inc, i64* %arrayidx4, align 8 + br label %return + +if.else5: ; preds = %entry + %arrayidx6 = getelementptr inbounds i64, i64* %a, i64 6 + %t3 = load i64, i64* %arrayidx6, align 8 + %cmp7 = icmp slt i64 %t3, 3512 + br i1 %cmp7, label %if.then8, label %return + +if.then8: ; preds = %if.else5 + %arrayidx9 = getelementptr inbounds i64, i64* %a, i64 7 + %t4 = load i64, i64* %arrayidx9, align 8 + %inc10 = add nsw i64 %t4, 1 + store i64 %inc10, i64* %arrayidx9, align 8 + br label %return + +return: ; preds = %if.else5, %if.then, %if.then8, %if.then3 + %retval.0 = phi i64 [ 214748364701, %if.then3 ], [ 214748364701, %if.then8 ], [ 250148364702, %if.then ], [ 256148364704, %if.else5 ] + ret i64 %retval.0 +} + +; Check when BFI is enabled for constant hoisting, constant 214748364701 +; in while.body will be hoisted to while.body.preheader. 214748364701 in +; if.then16 and if.else10 will be merged and hoisted to the beginning of +; if.else10 because if.else10 dominates if.then16. 
+; CHECK-LABEL: @goo( +; CHECK: entry: +; CHECK-NOT: bitcast i64 214748364701 to i64 +; CHECK: while.body.preheader: +; CHECK-NEXT: bitcast i64 214748364701 to i64 +; CHECK-NOT: bitcast i64 214748364701 to i64 +; CHECK: if.else10: +; CHECK-NEXT: bitcast i64 214748364701 to i64 +; CHECK-NOT: bitcast i64 214748364701 to i64 +define i64 @goo(i64* nocapture %a) { +entry: + %arrayidx = getelementptr inbounds i64, i64* %a, i64 9 + %t0 = load i64, i64* %arrayidx, align 8 + %cmp = icmp ult i64 %t0, 56 + br i1 %cmp, label %if.then, label %if.else10, !prof !0 + +if.then: ; preds = %entry + %arrayidx1 = getelementptr inbounds i64, i64* %a, i64 5 + %t1 = load i64, i64* %arrayidx1, align 8 + %cmp2 = icmp ult i64 %t1, 10 + br i1 %cmp2, label %while.cond.preheader, label %return, !prof !0 + +while.cond.preheader: ; preds = %if.then + %arrayidx7 = getelementptr inbounds i64, i64* %a, i64 6 + %t2 = load i64, i64* %arrayidx7, align 8 + %cmp823 = icmp ugt i64 %t2, 10000 + br i1 %cmp823, label %while.body.preheader, label %return + +while.body.preheader: ; preds = %while.cond.preheader + br label %while.body + +while.body: ; preds = %while.body.preheader, %while.body + %t3 = phi i64 [ %add, %while.body ], [ %t2, %while.body.preheader ] + %add = add i64 %t3, 214748364701 + %cmp8 = icmp ugt i64 %add, 10000 + br i1 %cmp8, label %while.body, label %while.cond.return.loopexit_crit_edge + +if.else10: ; preds = %entry + %arrayidx11 = getelementptr inbounds i64, i64* %a, i64 6 + %t4 = load i64, i64* %arrayidx11, align 8 + %add2 = add i64 %t4, 214748364701 + %cmp12 = icmp ult i64 %add2, 35 + br i1 %cmp12, label %if.then16, label %return, !prof !0 + +if.then16: ; preds = %if.else10 + %arrayidx17 = getelementptr inbounds i64, i64* %a, i64 7 + %t5 = load i64, i64* %arrayidx17, align 8 + %inc = add i64 %t5, 1 + store i64 %inc, i64* %arrayidx17, align 8 + br label %return + +while.cond.return.loopexit_crit_edge: ; preds = %while.body + store i64 %add, i64* %arrayidx7, align 8 + br label %return + 
+return: ; preds = %while.cond.preheader, %while.cond.return.loopexit_crit_edge, %if.else10, %if.then, %if.then16 + %retval.0 = phi i64 [ 214748364701, %if.then16 ], [ 0, %if.then ], [ 0, %if.else10 ], [ 0, %while.cond.return.loopexit_crit_edge ], [ 0, %while.cond.preheader ] + ret i64 %retval.0 +} + +!0 = !{!"branch_weights", i32 1, i32 2000} Index: test/CodeGen/X86/fold-tied-op.ll =================================================================== --- test/CodeGen/X86/fold-tied-op.ll +++ test/CodeGen/X86/fold-tied-op.ll @@ -7,7 +7,6 @@ ; CHECK-LABEL: fn1 ; CHECK: addl {{.*#+}} 4-byte Folded Reload -; CHECK: addl {{.*#+}} 4-byte Folded Reload ; CHECK: imull {{.*#+}} 4-byte Folded Reload ; CHECK: orl {{.*#+}} 4-byte Folded Reload ; CHECK: retl