diff --git a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h --- a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h @@ -145,14 +145,14 @@ slpvectorizer::BoUpSLP &R); /// Tries to vectorize \p CmpInts. \Returns true on success. - bool vectorizeCmpInsts(ArrayRef CmpInsts, BasicBlock *BB, + template + bool vectorizeCmpInsts(RangeT CmpInsts, BasicBlock *BB, slpvectorizer::BoUpSLP &R); - /// Tries to vectorize constructs started from CmpInst, InsertValueInst or + /// Tries to vectorize constructs started from InsertValueInst or /// InsertElementInst instructions. - bool vectorizeSimpleInstructions(InstSetVector &Instructions, BasicBlock *BB, - slpvectorizer::BoUpSLP &R, - bool AtTerminator); + bool vectorizeInserts(InstSetVector &Instructions, BasicBlock *BB, + slpvectorizer::BoUpSLP &R); /// Scan the basic block and look for patterns that are likely to start /// a vectorization chain. diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -14392,8 +14392,9 @@ return IsCompatibility; } -bool SLPVectorizerPass::vectorizeCmpInsts(ArrayRef CmpInsts, - BasicBlock *BB, BoUpSLP &R) { +template +bool SLPVectorizerPass::vectorizeCmpInsts(RangeT CmpInsts, BasicBlock *BB, + BoUpSLP &R) { bool Changed = false; // Try to find reductions first. 
for (Instruction *I : CmpInsts) { @@ -14443,25 +14444,19 @@ return Changed; } -bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions, - BasicBlock *BB, BoUpSLP &R, - bool AtTerminator) { +bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions, + BasicBlock *BB, BoUpSLP &R) { assert(all_of(Instructions, [](auto *I) { - return isa(I); + return isa(I); }) && - "This function only accepts Cmp and Insert instructions"); + "This function only accepts Insert instructions"); bool OpsChanged = false; - SmallVector PostponedCmps; SmallVector PostponedInsts; // pass1 - try to vectorize reductions only for (auto *I : reverse(Instructions)) { if (R.isDeleted(I)) continue; - if (isa(I)) { - PostponedCmps.push_back(cast(I)); - continue; - } OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts); } // pass2 - try to match and vectorize a buildvector sequence. @@ -14477,15 +14472,7 @@ // Now try to vectorize postponed instructions. OpsChanged |= tryToVectorize(PostponedInsts, R); - if (AtTerminator) { - OpsChanged |= vectorizeCmpInsts(PostponedCmps, BB, R); - Instructions.clear(); - } else { - Instructions.clear(); - // Insert in reverse order since the PostponedCmps vector was filled in - // reverse order. - Instructions.insert(PostponedCmps.rbegin(), PostponedCmps.rend()); - } + Instructions.clear(); return OpsChanged; } @@ -14633,8 +14620,26 @@ VisitedInstrs.clear(); - InstSetVector PostProcessInstructions; + InstSetVector PostProcessInserts; + SmallSetVector PostProcessCmps; SmallDenseSet KeyNodes; + // Vectorizes Inserts in `PostProcessInserts` and if `VectorizeCmps` is true + // also vectorizes `PostProcessCmps`. 
+ auto VectorizeInsertsAndCmps = [&](bool VectorizeCmps) { + bool Changed = vectorizeInserts(PostProcessInserts, BB, R); + if (VectorizeCmps) { + Changed |= vectorizeCmpInsts(reverse(PostProcessCmps), BB, R); + PostProcessCmps.clear(); + } + PostProcessInserts.clear(); + return Changed; + }; + // Returns true if `I` is in `PostProcessInserts` or `PostProcessCmps`. + auto IsInPostProcessInstrs = [&](Instruction *I) { + return (isa(I) && + PostProcessInserts.contains(I)) || + (isa(I) && PostProcessCmps.contains(cast(I))); + }; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { // Skip instructions with scalable type. The num of elements is unknown at // compile-time for scalable type. @@ -14647,8 +14652,7 @@ // We may go through BB multiple times so skip the one we have checked. if (!VisitedInstrs.insert(&*it).second) { if (it->use_empty() && KeyNodes.contains(&*it) && - vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator())) { + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. Changed = true; @@ -14688,7 +14692,7 @@ // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *PI = dyn_cast(P->getIncomingValue(I)); - PI && !PostProcessInstructions.contains(PI)) + PI && !IsInPostProcessInstrs(PI)) Changed |= vectorizeRootInstruction(nullptr, PI, P->getIncomingBlock(I), R, TTI); } @@ -14720,7 +14724,7 @@ // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *VI = dyn_cast(V); - VI && !PostProcessInstructions.contains(VI)) + VI && !IsInPostProcessInstrs(VI)) // Try to match and vectorize a horizontal reduction. 
OpsChanged |= vectorizeRootInstruction(nullptr, VI, BB, R, TTI); } @@ -14728,8 +14732,8 @@ // Start vectorization of post-process list of instructions from the // top-tree instructions to try to vectorize as many instructions as // possible. - OpsChanged |= vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator()); + OpsChanged |= + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator()); if (OpsChanged) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. @@ -14740,8 +14744,10 @@ } } - if (isa(it)) - PostProcessInstructions.insert(&*it); + if (isa(it)) + PostProcessInserts.insert(&*it); + else if (isa(it)) + PostProcessCmps.insert(cast(&*it)); } return Changed;