diff --git a/llvm/include/llvm/Transforms/Scalar/ConstantHoisting.h b/llvm/include/llvm/Transforms/Scalar/ConstantHoisting.h --- a/llvm/include/llvm/Transforms/Scalar/ConstantHoisting.h +++ b/llvm/include/llvm/Transforms/Scalar/ConstantHoisting.h @@ -36,6 +36,7 @@ #ifndef LLVM_TRANSFORMS_SCALAR_CONSTANTHOISTING_H #define LLVM_TRANSFORMS_SCALAR_CONSTANTHOISTING_H +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PointerUnion.h" @@ -168,9 +169,13 @@ /// Keep track of cast instructions we already cloned. MapVector ClonedCastMap; + void collectMatInsertPts( + const consthoist::RebasedConstantListType &RebasedConstants, + SmallVectorImpl &MatInsertPts) const; Instruction *findMatInsertPt(Instruction *Inst, unsigned Idx = ~0U) const; SetVector - findConstantInsertionPoint(const consthoist::ConstantInfo &ConstInfo) const; + findConstantInsertionPoint(const consthoist::ConstantInfo &ConstInfo, + const ArrayRef MatInsertPts) const; void collectConstantCandidates(ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, ConstantInt *ConstInt); @@ -197,9 +202,11 @@ struct UserAdjustment { Constant *Offset; Type *Ty; + Instruction *MatInsertPt; const consthoist::ConstantUser User; - UserAdjustment(Constant *O, Type *T, consthoist::ConstantUser U) - : Offset(O), Ty(T), User(U) {} + UserAdjustment(Constant *O, Type *T, Instruction *I, + consthoist::ConstantUser U) + : Offset(O), Ty(T), MatInsertPt(I), User(U) {} }; void emitBaseConstants(Instruction *Base, UserAdjustment *Adj); // If BaseGV is nullptr, emit Constant Integer base; otherwise emit diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp --- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -160,6 +160,14 @@ return MadeChange; } +void ConstantHoistingPass::collectMatInsertPts( + const RebasedConstantListType &RebasedConstants, + SmallVectorImpl &MatInsertPts) const { + for (const RebasedConstantInfo &RCI : RebasedConstants) + for (const ConstantUser &U : RCI.Uses) + MatInsertPts.emplace_back(findMatInsertPt(U.Inst, U.OpndIdx)); +} + /// Find the constant materialization insertion point. Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, unsigned Idx) const { @@ -307,14 +315,15 @@ /// Find an insertion point that dominates all uses. SetVector ConstantHoistingPass::findConstantInsertionPoint( - const ConstantInfo &ConstInfo) const { + const ConstantInfo &ConstInfo, + const ArrayRef MatInsertPts) const { assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry."); // Collect all basic blocks. SetVector BBs; SetVector InsertPts; - for (auto const &RCI : ConstInfo.RebasedConstants) - for (auto const &U : RCI.Uses) - BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); + + for (Instruction *MatInsertPt : MatInsertPts) + BBs.insert(MatInsertPt->getParent()); if (BBs.count(Entry)) { InsertPts.insert(&Entry->front()); @@ -750,20 +759,18 @@ Adj->Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0); if (Adj->Offset) { - Instruction *InsertionPt = - findMatInsertPt(Adj->User.Inst, Adj->User.OpndIdx); if (Adj->Ty) { // Constant being rebased is a ConstantExpr. PointerType *Int8PtrTy = Type::getInt8PtrTy( *Ctx, cast(Adj->Ty)->getAddressSpace()); - Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", InsertionPt); + Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", Adj->MatInsertPt); Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset, - "mat_gep", InsertionPt); - Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", InsertionPt); + "mat_gep", Adj->MatInsertPt); + Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast", Adj->MatInsertPt); } else // Constant being rebased is a ConstantInt. Mat = BinaryOperator::Create(Instruction::Add, Base, Adj->Offset, - "const_mat", InsertionPt); + "const_mat", Adj->MatInsertPt); LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) << " + " << *Adj->Offset << ") in BB " @@ -814,8 +821,7 @@ // Aside from constant GEPs, only constant cast expressions are collected. assert(ConstExpr->isCast() && "ConstExpr should be a cast"); - Instruction *ConstExprInst = ConstExpr->getAsInstruction( - findMatInsertPt(Adj->User.Inst, Adj->User.OpndIdx)); + Instruction *ConstExprInst = ConstExpr->getAsInstruction(Adj->MatInsertPt); ConstExprInst->setOperand(0, Mat); // Use the same debug location as the instruction we are about to update. @@ -840,8 +846,11 @@ bool MadeChange = false; SmallVectorImpl &ConstInfoVec = BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; - for (auto const &ConstInfo : ConstInfoVec) { - SetVector IPSet = findConstantInsertionPoint(ConstInfo); + for (const consthoist::ConstantInfo &ConstInfo : ConstInfoVec) { + SmallVector MatInsertPts; + collectMatInsertPts(ConstInfo.RebasedConstants, MatInsertPts); + SetVector IPSet = + findConstantInsertionPoint(ConstInfo, MatInsertPts); // We can have an empty set if the function contains unreachable blocks. if (IPSet.empty()) continue; @@ -853,16 +862,17 @@ // First, collect constants depending on this IP of the base. UsesNum = 0; SmallVector ToBeRebased; + unsigned MatCtr = 0; for (auto const &RCI : ConstInfo.RebasedConstants) { UsesNum += RCI.Uses.size(); for (auto const &U : RCI.Uses) { - BasicBlock *OrigMatInsertBB = - findMatInsertPt(U.Inst, U.OpndIdx)->getParent(); + Instruction *MatInsertPt = MatInsertPts[MatCtr++]; + BasicBlock *OrigMatInsertBB = MatInsertPt->getParent(); // If Base constant is to be inserted in multiple places, // generate rebase for U using the Base dominating U. if (IPSet.size() == 1 || DT->dominates(IP->getParent(), OrigMatInsertBB)) - ToBeRebased.emplace_back(RCI.Offset, RCI.Ty, U); + ToBeRebased.emplace_back(RCI.Offset, RCI.Ty, MatInsertPt, U); } }