diff --git a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h --- a/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h +++ b/llvm/include/llvm/Transforms/Vectorize/SLPVectorizer.h @@ -144,14 +144,14 @@ slpvectorizer::BoUpSLP &R); /// Tries to vectorize \p CmpInts. \Returns true on success. - bool vectorizeCmpInsts(ArrayRef CmpInsts, BasicBlock *BB, + template + bool vectorizeCmpInsts(iterator_range CmpInsts, BasicBlock *BB, slpvectorizer::BoUpSLP &R); - /// Tries to vectorize constructs started from CmpInst, InsertValueInst or + /// Tries to vectorize constructs started from InsertValueInst or /// InsertElementInst instructions. - bool vectorizeSimpleInstructions(InstSetVector &Instructions, BasicBlock *BB, - slpvectorizer::BoUpSLP &R, - bool AtTerminator); + bool vectorizeInserts(InstSetVector &Instructions, BasicBlock *BB, + slpvectorizer::BoUpSLP &R); /// Scan the basic block and look for patterns that are likely to start /// a vectorization chain. diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -14399,11 +14399,12 @@ return IsCompatibility; } -bool SLPVectorizerPass::vectorizeCmpInsts(ArrayRef CmpInsts, +template +bool SLPVectorizerPass::vectorizeCmpInsts(iterator_range CmpInsts, BasicBlock *BB, BoUpSLP &R) { bool Changed = false; // Try to find reductions first. - for (Instruction *I : CmpInsts) { + for (CmpInst *I : CmpInsts) { if (R.isDeleted(I)) continue; for (Value *Op : I->operands()) @@ -14411,7 +14412,7 @@ Changed |= vectorizeRootInstruction(nullptr, RootOp, BB, R, TTI); } // Try to vectorize operands as vector bundles. - for (Instruction *I : CmpInsts) { + for (CmpInst *I : CmpInsts) { if (R.isDeleted(I)) continue; Changed |= tryToVectorize(I, R); @@ -14450,25 +14451,19 @@ return Changed; } -bool SLPVectorizerPass::vectorizeSimpleInstructions(InstSetVector &Instructions, - BasicBlock *BB, BoUpSLP &R, - bool AtTerminator) { +bool SLPVectorizerPass::vectorizeInserts(InstSetVector &Instructions, + BasicBlock *BB, BoUpSLP &R) { assert(all_of(Instructions, [](auto *I) { - return isa(I); + return isa(I); }) && - "This function only accepts Cmp and Insert instructions"); + "This function only accepts Insert instructions"); bool OpsChanged = false; - SmallVector PostponedCmps; SmallVector PostponedInsts; // pass1 - try to vectorize reductions only for (auto *I : reverse(Instructions)) { if (R.isDeleted(I)) continue; - if (isa(I)) { - PostponedCmps.push_back(cast(I)); - continue; - } OpsChanged |= vectorizeHorReduction(nullptr, I, BB, R, TTI, PostponedInsts); } // pass2 - try to match and vectorize a buildvector sequence. @@ -14484,15 +14479,7 @@ // Now try to vectorize postponed instructions. OpsChanged |= tryToVectorize(PostponedInsts, R); - if (AtTerminator) { - OpsChanged |= vectorizeCmpInsts(PostponedCmps, BB, R); - Instructions.clear(); - } else { - Instructions.clear(); - // Insert in reverse order since the PostponedCmps vector was filled in - // reverse order. - Instructions.insert(PostponedCmps.rbegin(), PostponedCmps.rend()); - } + Instructions.clear(); return OpsChanged; } @@ -14640,8 +14627,27 @@ VisitedInstrs.clear(); - InstSetVector PostProcessInstructions; + InstSetVector PostProcessInserts; + SmallSetVector PostProcessCmps; SmallDenseSet KeyNodes; + // Vectorizes Inserts in `PostProcessInserts` and if `VecctorizeCmps` is true + // also vectorizes `PostProcessCmps`. + auto VectorizeInsertsAndCmps = [&](bool VectorizeCmps) { + bool Changed = vectorizeInserts(PostProcessInserts, BB, R); + if (VectorizeCmps) { + Changed |= vectorizeCmpInsts(reverse(PostProcessCmps), BB, R); + PostProcessCmps.clear(); + } + PostProcessInserts.clear(); + return Changed; + }; + // Returns true if `I` is in `PostProcessInserts` or `PostProcessCmps`. + auto IsInPostProcessInstrs = [&](Instruction *I) { + if (auto *Cmp = dyn_cast(I)) + return PostProcessCmps.contains(Cmp); + return isa(I) && + PostProcessInserts.contains(I); + }; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) { // Skip instructions with scalable type. The num of elements is unknown at // compile-time for scalable type. @@ -14654,8 +14660,7 @@ // We may go through BB multiple times so skip the one we have checked. if (!VisitedInstrs.insert(&*it).second) { if (it->use_empty() && KeyNodes.contains(&*it) && - vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator())) { + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator())) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. Changed = true; @@ -14695,7 +14700,7 @@ // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *PI = dyn_cast(P->getIncomingValue(I)); - PI && !PostProcessInstructions.contains(PI)) + PI && !IsInPostProcessInstrs(PI)) Changed |= vectorizeRootInstruction(nullptr, PI, P->getIncomingBlock(I), R, TTI); } @@ -14727,7 +14732,7 @@ // Postponed instructions should not be vectorized here, delay their // vectorization. if (auto *VI = dyn_cast(V); - VI && !PostProcessInstructions.contains(VI)) + VI && !IsInPostProcessInstrs(VI)) // Try to match and vectorize a horizontal reduction. OpsChanged |= vectorizeRootInstruction(nullptr, VI, BB, R, TTI); } @@ -14735,8 +14740,8 @@ // Start vectorization of post-process list of instructions from the // top-tree instructions to try to vectorize as many instructions as // possible. - OpsChanged |= vectorizeSimpleInstructions(PostProcessInstructions, BB, R, - it->isTerminator()); + OpsChanged |= + VectorizeInsertsAndCmps(/*VectorizeCmps=*/it->isTerminator()); if (OpsChanged) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. @@ -14747,8 +14752,10 @@ } } - if (isa(it)) - PostProcessInstructions.insert(&*it); + if (isa(it)) + PostProcessInserts.insert(&*it); + else if (isa(it)) + PostProcessCmps.insert(cast(&*it)); } return Changed;