Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -306,21 +306,30 @@ /// The very first instruction in the list with the main opcode. Value *OpValue = nullptr; + /// The main/alternate instruction. + Instruction *MainOp = nullptr; + Instruction *AltOp = nullptr; + /// The main/alternate opcodes for the list of instructions. - unsigned Opcode = 0; - unsigned AltOpcode = 0; + unsigned getOpcode() const { + return MainOp ? MainOp->getOpcode() : 0; + } + + unsigned getAltOpcode() const { + return AltOp ? AltOp->getOpcode() : 0; + } /// Some of the instructions in the list have alternate opcodes. - bool isAltShuffle() const { return Opcode != AltOpcode; } + bool isAltShuffle() const { return getOpcode() != getAltOpcode(); } bool isOpcodeOrAlt(Instruction *I) const { unsigned CheckedOpcode = I->getOpcode(); - return Opcode == CheckedOpcode || AltOpcode == CheckedOpcode; + return getOpcode() == CheckedOpcode || getAltOpcode() == CheckedOpcode; } - InstructionsState() = default; - InstructionsState(Value *OpValue, unsigned Opcode, unsigned AltOpcode) - : OpValue(OpValue), Opcode(Opcode), AltOpcode(AltOpcode) {} + InstructionsState() = delete; + InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp) + : OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {} }; } // end anonymous namespace @@ -342,11 +351,12 @@ unsigned BaseIndex = 0) { // Make sure these are all Instructions. if (llvm::any_of(VL, [](Value *V) { return !isa(V); })) - return InstructionsState(VL[BaseIndex], 0, 0); + return InstructionsState(VL[BaseIndex], nullptr, nullptr); bool IsBinOp = isa(VL[BaseIndex]); unsigned Opcode = cast(VL[BaseIndex])->getOpcode(); unsigned AltOpcode = Opcode; + unsigned AltIndex = BaseIndex; // Check for one alternate opcode from another BinaryOperator. // TODO - can we support other operators (casts etc.)? @@ -355,13 +365,15 @@ if (InstOpcode != Opcode && InstOpcode != AltOpcode) { if (Opcode == AltOpcode && IsBinOp && isa(VL[Cnt])) { AltOpcode = InstOpcode; + AltIndex = Cnt; continue; } - return InstructionsState(VL[BaseIndex], 0, 0); + return InstructionsState(VL[BaseIndex], nullptr, nullptr); } } - return InstructionsState(VL[BaseIndex], Opcode, AltOpcode); + return InstructionsState(VL[BaseIndex], cast(VL[BaseIndex]), + cast(VL[AltIndex])); } /// \returns true if all of the values in \p VL have the same type or false @@ -1390,7 +1402,7 @@ } // If all of the operands are identical or constant we have a simple solution. - if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.Opcode) { + if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.getOpcode()) { LLVM_DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); newTreeEntry(VL, false, UserTreeIdx); return; @@ -1500,7 +1512,7 @@ LLVM_DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast(VL0); @@ -1734,7 +1746,7 @@ // have the same opcode. if (isa(VL0) && VL0->isCommutative()) { ValueList Left, Right; - reorderInputsAccordingToOpcode(S.Opcode, VL, Left, Right); + reorderInputsAccordingToOpcode(S.getOpcode(), VL, Left, Right); buildTree_rec(Left, Depth + 1, UserTreeIdx); buildTree_rec(Right, Depth + 1, UserTreeIdx); return; @@ -1952,7 +1964,7 @@ Instruction *E0 = cast(OpValue); assert(E0->getOpcode() == Instruction::ExtractElement || E0->getOpcode() == Instruction::ExtractValue); - assert(E0->getOpcode() == getSameOpcode(VL).Opcode && "Invalid opcode"); + assert(E0->getOpcode() == getSameOpcode(VL).getOpcode() && "Invalid opcode"); // Check if all of the extracts come from the same vector and from the // correct offset. Value *Vec = E0->getOperand(0); @@ -2051,7 +2063,7 @@ return ReuseShuffleCost + TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0); } - if (getSameOpcode(VL).Opcode == Instruction::ExtractElement && + if (getSameOpcode(VL).getOpcode() == Instruction::ExtractElement && allSameType(VL) && allSameBlock(VL)) { Optional ShuffleKind = isShuffle(VL); if (ShuffleKind.hasValue()) { @@ -2075,10 +2087,10 @@ return ReuseShuffleCost + getGatherCost(VL); } InstructionsState S = getSameOpcode(VL); - assert(S.Opcode && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); + assert(S.getOpcode() && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); Instruction *VL0 = cast(S.OpValue); unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: return 0; @@ -2166,7 +2178,7 @@ if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * - TTI->getCastInstrCost(S.Opcode, ScalarTy, SrcTy, VL0); + TTI->getCastInstrCost(S.getOpcode(), ScalarTy, SrcTy, VL0); } // Calculate the cost of this instruction. @@ -2188,13 +2200,14 @@ // Calculate the cost of this instruction. if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * - TTI->getCmpSelInstrCost(S.Opcode, ScalarTy, + TTI->getCmpSelInstrCost(S.getOpcode(), ScalarTy, Builder.getInt1Ty(), VL0); } VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); int ScalarCost = VecTy->getNumElements() * - TTI->getCmpSelInstrCost(S.Opcode, ScalarTy, Builder.getInt1Ty(), VL0); - int VecCost = TTI->getCmpSelInstrCost(S.Opcode, VecTy, MaskTy, VL0); + TTI->getCmpSelInstrCost(S.getOpcode(), ScalarTy, + Builder.getInt1Ty(), VL0); + int VecCost = TTI->getCmpSelInstrCost(S.getOpcode(), VecTy, MaskTy, VL0); return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::Add: @@ -2255,15 +2268,15 @@ if (NeedToShuffleReuses) { ReuseShuffleCost -= (ReuseShuffleNumbers - VL.size()) * - TTI->getArithmeticInstrCost(S.Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, - Op2VP, Operands); + TTI->getArithmeticInstrCost(S.getOpcode(), ScalarTy, Op1VK, Op2VK, + Op1VP, Op2VP, Operands); } int ScalarCost = VecTy->getNumElements() * - TTI->getArithmeticInstrCost(S.Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, - Op2VP, Operands); - int VecCost = TTI->getArithmeticInstrCost(S.Opcode, VecTy, Op1VK, Op2VK, - Op1VP, Op2VP, Operands); + TTI->getArithmeticInstrCost(S.getOpcode(), ScalarTy, Op1VK, Op2VK, + Op1VP, Op2VP, Operands); + int VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy, Op1VK, + Op2VK, Op1VP, Op2VP, Operands); return ReuseShuffleCost + VecCost - ScalarCost; } case Instruction::GetElementPtr: { @@ -2350,8 +2363,8 @@ return ReuseShuffleCost + VecCallCost - ScalarCallCost; } case Instruction::ShuffleVector: { - assert(S.isAltShuffle() && Instruction::isBinaryOp(S.Opcode) && - Instruction::isBinaryOp(S.AltOpcode) && + assert(S.isAltShuffle() && Instruction::isBinaryOp(S.getOpcode()) && + Instruction::isBinaryOp(S.getAltOpcode()) && "Invalid Shuffle Vector Operand"); int ScalarCost = 0; if (NeedToShuffleReuses) { @@ -2374,8 +2387,8 @@ } // VecCost is equal to sum of the cost of creating 2 vectors // and the cost of creating shuffle. - VecCost = TTI->getArithmeticInstrCost(S.Opcode, VecTy); - VecCost += TTI->getArithmeticInstrCost(S.AltOpcode, VecTy); + VecCost = TTI->getArithmeticInstrCost(S.getOpcode(), VecTy); + VecCost += TTI->getArithmeticInstrCost(S.getAltOpcode(), VecTy); VecCost += TTI->getShuffleCost(TargetTransformInfo::SK_Select, VecTy, 0); return ReuseShuffleCost + VecCost - ScalarCost; } @@ -2923,7 +2936,7 @@ Value *BoUpSLP::vectorizeTree(ArrayRef VL) { InstructionsState S = getSameOpcode(VL); - if (S.Opcode) { + if (S.getOpcode()) { if (TreeEntry *E = getTreeEntry(S.OpValue)) { if (E->isSame(VL)) { Value *V = vectorizeTree(E); @@ -3026,7 +3039,7 @@ } unsigned ShuffleOrOp = S.isAltShuffle() ? - (unsigned) Instruction::ShuffleVector : S.Opcode; + (unsigned) Instruction::ShuffleVector : S.getOpcode(); switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast(VL0); @@ -3191,7 +3204,7 @@ CmpInst::Predicate P0 = cast(VL0)->getPredicate(); Value *V; - if (S.Opcode == Instruction::FCmp) + if (S.getOpcode() == Instruction::FCmp) V = Builder.CreateFCmp(P0, L, R); else V = Builder.CreateICmp(P0, L, R); @@ -3253,7 +3266,7 @@ case Instruction::Xor: { ValueList LHSVL, RHSVL; if (isa(VL0) && VL0->isCommutative()) - reorderInputsAccordingToOpcode(S.Opcode, E->Scalars, LHSVL, + reorderInputsAccordingToOpcode(S.getOpcode(), E->Scalars, LHSVL, RHSVL); else for (Value *V : E->Scalars) { @@ -3273,7 +3286,7 @@ } Value *V = Builder.CreateBinOp( - static_cast(S.Opcode), LHS, RHS); + static_cast(S.getOpcode()), LHS, RHS); propagateIRFlags(V, E->Scalars, VL0); if (auto *I = dyn_cast(V)) V = propagateMetadata(I, E->Scalars); @@ -3457,8 +3470,8 @@ } case Instruction::ShuffleVector: { ValueList LHSVL, RHSVL; - assert(S.isAltShuffle() && Instruction::isBinaryOp(S.Opcode) && - Instruction::isBinaryOp(S.AltOpcode) && + assert(S.isAltShuffle() && Instruction::isBinaryOp(S.getOpcode()) && + Instruction::isBinaryOp(S.getAltOpcode()) && "Invalid Shuffle Vector Operand"); reorderAltShuffleOperands(S, E->Scalars, LHSVL, RHSVL); setInsertPointAfterBundle(E->Scalars, S); @@ -3473,11 +3486,11 @@ // Create a vector of LHS op1 RHS Value *V0 = Builder.CreateBinOp( - static_cast(S.Opcode), LHS, RHS); + static_cast(S.getOpcode()), LHS, RHS); // Create a vector of LHS op2 RHS Value *V1 = Builder.CreateBinOp( - static_cast(S.AltOpcode), LHS, RHS); + static_cast(S.getAltOpcode()), LHS, RHS); // Create shuffle to take alternate operations from the vector. // Also, gather up odd and even scalar ops to propagate IR flags to @@ -3488,7 +3501,7 @@ for (unsigned i = 0; i < e; ++i) { auto *OpInst = cast(E->Scalars[i]); assert(S.isOpcodeOrAlt(OpInst) && "Unexpected main/alternate opcode"); - if (OpInst->getOpcode() == S.AltOpcode) { + if (OpInst->getOpcode() == S.getAltOpcode()) { Mask[i] = Builder.getInt32(e + i); AltScalars.push_back(E->Scalars[i]); } else { @@ -4817,7 +4830,7 @@ // Check that all of the parts are scalar instructions of the same type, // we permit an alternate opcode via InstructionsState. InstructionsState S = getSameOpcode(VL); - if (!S.Opcode) + if (!S.getOpcode()) return false; Instruction *I0 = cast(S.OpValue);