Summary: pass a representative value (OpValue) to setInsertPointAfterBundle() and record one in ScheduleData (assigned via init()). When locating a bundle's last instruction, only scalars whose opcode equals OpValue's opcode or its alternate opcode (sameOpcodeOrAlt) are considered, and isOneOf() chooses the ScheduleData lookup key. The Call case previously passed VL0 alone and now passes E->Scalars, VL0 like every other case.
NOTE(review): this copy of the patch has lost its line breaks and every template argument list (e.g. `ArrayRef VL`, `dyn_cast(Op)`, `cast(V)`, `SmallPtrSet Bundle`, `isa(VL0)`), so it neither applies nor compiles as shown; regenerate it from the original source instead of hand-repairing it.
NOTE(review): in setInsertPointAfterBundle, `Opcode` re-casts OpValue although `Front` already holds that cast; `Front->getOpcode()` would avoid the duplicate.
Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -178,6 +178,26 @@ return Value & 1; } +static bool sameOpcodeOrAlt(unsigned Opcode, unsigned AltOpcode, + unsigned CheckedOpcode) { + return Opcode == CheckedOpcode || AltOpcode == CheckedOpcode; +} + +/// Chooses the correct key for scheduling data. If \p Op has the same (or +/// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is \p +/// OpValue. +static Value *isOneOf(Value *OpValue, Value *Op) { + auto *I = dyn_cast(Op); + if (!I) + return OpValue; + auto *OpInst = cast(OpValue); + unsigned OpInstOpcode = OpInst->getOpcode(); + unsigned IOpcode = I->getOpcode(); + if (sameOpcodeOrAlt(OpInstOpcode, getAltOpcode(OpInstOpcode), IOpcode)) + return Op; + return OpValue; +} + ///\returns bool representing if Opcode \p Op can be part /// of an alternate sequence which can later be merged as /// a ShuffleVector instruction. @@ -447,7 +467,7 @@ /// \brief Set the Builder insert point to one after the last instruction in /// the bundle - void setInsertPointAfterBundle(ArrayRef VL); + void setInsertPointAfterBundle(ArrayRef VL, Value *OpValue); /// \returns a vector from a collection of scalars in \p VL. 
Value *Gather(ArrayRef VL, VectorType *Ty); @@ -633,9 +653,10 @@ : Inst(nullptr), FirstInBundle(nullptr), NextInBundle(nullptr), NextLoadStore(nullptr), SchedulingRegionID(0), SchedulingPriority(0), Dependencies(InvalidDeps), UnscheduledDeps(InvalidDeps), - UnscheduledDepsInBundle(InvalidDeps), IsScheduled(false) {} + UnscheduledDepsInBundle(InvalidDeps), IsScheduled(false), + OpValue(nullptr) {} - void init(int BlockSchedulingRegionID) { + void init(int BlockSchedulingRegionID, Value *OpVal) { FirstInBundle = this; NextInBundle = nullptr; NextLoadStore = nullptr; @@ -643,6 +664,7 @@ SchedulingRegionID = BlockSchedulingRegionID; UnscheduledDepsInBundle = UnscheduledDeps; clearDependencies(); + OpValue = OpVal; } /// Returns true if the dependency information has been calculated. @@ -747,6 +769,9 @@ /// True if this instruction is scheduled (or considered as scheduled in the /// dry-run). bool IsScheduled; + + /// Value passed to init() as OpVal (a Value *, not an opcode); compared against Inst when locating a bundle's last instruction. + Value *OpValue; }; #ifndef NDEBUG @@ -2336,14 +2361,18 @@ } } -void BoUpSLP::setInsertPointAfterBundle(ArrayRef VL) { +void BoUpSLP::setInsertPointAfterBundle(ArrayRef VL, Value *OpValue) { // Get the basic block this bundle is in. All instructions in the bundle // should be in this block. - auto *Front = cast(VL.front()); + auto *Front = cast(OpValue); auto *BB = Front->getParent(); - assert(all_of(make_range(VL.begin(), VL.end()), [&](Value *V) -> bool { - return cast(V)->getParent() == BB; + const unsigned Opcode = cast(OpValue)->getOpcode(); + const unsigned AltOpcode = getAltOpcode(Opcode); + assert(all_of(make_range(VL.begin(), VL.end()), [=](Value *V) -> bool { + return !sameOpcodeOrAlt(Opcode, AltOpcode, + cast(V)->getOpcode()) || + cast(V)->getParent() == BB; })); // The last instruction in the bundle in program order. @@ -2354,10 +2383,12 @@ // VL.back() and iterate over schedule data until we reach the end of the // bundle. The end of the bundle is marked by null ScheduleData. 
if (BlocksSchedules.count(BB)) { - auto *Bundle = BlocksSchedules[BB]->getScheduleData(VL.back()); + auto *Bundle = + BlocksSchedules[BB]->getScheduleData(isOneOf(OpValue, VL.back())); if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) - LastInst = Bundle->Inst; + if (Bundle->OpValue == Bundle->Inst) + LastInst = Bundle->Inst; } // LastInst can still be null at this point if there's either not an entry @@ -2381,7 +2412,7 @@ if (!LastInst) { SmallPtrSet Bundle(VL.begin(), VL.end()); for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) { - if (Bundle.erase(&I)) + if (Bundle.erase(&I) && sameOpcodeOrAlt(Opcode, AltOpcode, I.getOpcode())) LastInst = &I; if (Bundle.empty()) break; @@ -2459,7 +2490,7 @@ VectorType *VecTy = VectorType::get(ScalarTy, E->Scalars.size()); if (E->NeedToGather) { - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; @@ -2509,7 +2540,7 @@ E->VectorizedValue = V; return V; } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; @@ -2524,7 +2555,7 @@ E->VectorizedValue = V; return propagateMetadata(V, E->Scalars); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; @@ -2545,7 +2576,7 @@ for (Value *V : E->Scalars) INVL.push_back(cast(V)->getOperand(0)); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *InVec = vectorizeTree(INVL); @@ -2566,7 +2597,7 @@ RHSV.push_back(cast(V)->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *L = vectorizeTree(LHSV); Value *R = vectorizeTree(RHSV); @@ -2594,7 +2625,7 @@ FalseVec.push_back(cast(V)->getOperand(2)); } - 
setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *Cond = vectorizeTree(CondVec); Value *True = vectorizeTree(TrueVec); @@ -2635,7 +2666,7 @@ RHSVL.push_back(cast(V)->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); @@ -2657,7 +2688,7 @@ case Instruction::Load: { // Loads are inserted at the head of the tree because we don't want to // sink them all the way down past store instructions. - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); LoadInst *LI = cast(VL0); Type *ScalarLoadTy = LI->getType(); @@ -2692,7 +2723,7 @@ for (Value *V : E->Scalars) ValueOp.push_back(cast(V)->getValueOperand()); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *VecValue = vectorizeTree(ValueOp); Value *VecPtr = Builder.CreateBitCast(SI->getPointerOperand(), @@ -2715,7 +2746,7 @@ return propagateMetadata(S, E->Scalars); } case Instruction::GetElementPtr: { - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); ValueList Op0VL; for (Value *V : E->Scalars) @@ -2746,7 +2777,7 @@ } case Instruction::Call: { CallInst *CI = cast(VL0); - setInsertPointAfterBundle(VL0); + setInsertPointAfterBundle(E->Scalars, VL0); Function *FI; Intrinsic::ID IID = Intrinsic::not_intrinsic; Value *ScalarArg = nullptr; @@ -2797,7 +2828,7 @@ ValueList LHSVL, RHSVL; assert(isa(VL0) && "Invalid Shuffle Vector Operand"); reorderAltShuffleOperands(E->Scalars, LHSVL, RHSVL); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); @@ -3279,7 +3310,7 @@ } assert(!isInSchedulingRegion(SD) && "new ScheduleData already in scheduling region"); - SD->init(SchedulingRegionID); + SD->init(SchedulingRegionID, I); if (I->mayReadOrWriteMemory()) { // 
Update the linked list of memory accessing instructions.