diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -6584,6 +6584,7 @@
                          &ILV,   CallbackILV};
   State.CFG.PrevBB = ILV.createVectorizedLoopSkeleton();
   State.TripCount = ILV.getOrCreateTripCount(nullptr);
+  State.PrimaryIV = ILV.Legal->getPrimaryInduction();
 
   //===------------------------------------------------===//
   //
@@ -6770,7 +6771,7 @@
 
   // Introduce the early-exit compare IV <= BTC to form header block mask.
   // This is used instead of IV < TC because TC may wrap, unlike BTC.
-  VPValue *IV = Plan->getVPValue(Legal->getPrimaryInduction());
+  VPValue *IV = Plan->getOrCreatePrimaryIV();
   VPValue *BTC = Plan->getOrCreateBackedgeTakenCount();
   BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC});
   return BlockMaskCache[BB] = BlockMask;
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -332,6 +332,9 @@
   /// Hold the trip count of the scalar loop.
   Value *TripCount = nullptr;
 
+  /// Hold the primary induction of the original loop; null if there is none.
+  Value *PrimaryIV = nullptr;
+
   /// Hold a pointer to InnerLoopVectorizer to reuse its IR generation methods.
   InnerLoopVectorizer *ILV;
 
@@ -1455,6 +1458,11 @@
   /// the tail.
   VPValue *BackedgeTakenCount = nullptr;
 
+  /// Represents the primary induction variable of the original loop, used to
+  /// form the header block mask. Owned by the VPlan and never entered into
+  /// Value2VPValue; it is bound to its IR value in execute().
+  VPValue *PrimaryIV = nullptr;
+
   /// Holds a mapping between Values and their corresponding VPValue inside
   /// VPlan.
   Value2VPValueTy Value2VPValue;
@@ -1479,6 +1487,8 @@
         delete MapEntry.second;
     if (BackedgeTakenCount)
       delete BackedgeTakenCount; // Delete once, if in Value2VPValue or not.
+    if (PrimaryIV)
+      delete PrimaryIV; // Never in Value2VPValue; owned solely by the VPlan.
     for (VPValue *Def : VPExternalDefs)
       delete Def;
     for (VPValue *CBV : VPCBVs)
@@ -1504,6 +1514,13 @@
     return BackedgeTakenCount;
   }
 
+  /// The primary induction variable of the original loop.
+  VPValue *getOrCreatePrimaryIV() {
+    if (!PrimaryIV)
+      PrimaryIV = new VPValue();
+    return PrimaryIV;
+  }
+
   void addVF(unsigned VF) { VFs.insert(VF); }
 
   bool hasVF(unsigned VF) { return VFs.count(VF); }
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -444,6 +444,13 @@
                                    "trip.count.minus.1");
     Value2VPValue[TCMO] = BackedgeTakenCount;
   }
+  // Bind PrimaryIV directly in the reverse map rather than in Value2VPValue:
+  // the same IR value may already be registered there, and overwriting that
+  // entry would leak (and unmap) the VPValue it holds.
+  if (PrimaryIV && PrimaryIV->getNumUsers()) {
+    assert(State->PrimaryIV && "PrimaryIV is used but the loop has none");
+    State->VPValue2Value[PrimaryIV] = State->PrimaryIV;
+  }
 
   // 0. Set the reverse mapping from VPValues to Values for code generation.
   for (auto &Entry : Value2VPValue)