diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6312,20 +6312,13 @@
     /// Opcode of the instruction.
     unsigned Opcode = 0;
 
-    /// Left operand of the reduction operation.
-    Value *LHS = nullptr;
-
-    /// Right operand of the reduction operation.
-    Value *RHS = nullptr;
-
     /// Kind of the reduction operation.
     ReductionKind Kind = RK_None;
 
     /// Checks if the reduction operation can be vectorized.
     bool isVectorizable() const {
-      return LHS && RHS &&
-             // We currently only support add/mul/logical && min/max reductions.
-             ((Kind == RK_Arithmetic &&
+      // We currently only support add/mul/logical && min/max reductions.
+      return ((Kind == RK_Arithmetic &&
                (Opcode == Instruction::Add || Opcode == Instruction::FAdd ||
                 Opcode == Instruction::Mul || Opcode == Instruction::FMul ||
                 Opcode == Instruction::And || Opcode == Instruction::Or ||
@@ -6336,7 +6329,8 @@
     }
 
-    /// Creates reduction operation with the current opcode.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name) const {
+    /// Creates reduction operation \p LHS op \p RHS with the current opcode.
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
       Value *Cmp = nullptr;
@@ -6377,8 +6371,8 @@
 
-    /// Constructor for reduction operations with opcode and its left and
-    /// right operands.
-    OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind)
-        : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) {
+    /// Constructor for reduction operations with opcode and reduction kind.
+    /// Operands are not stored: see getLHS()/getRHS() and createOp().
+    OperationData(unsigned Opcode, ReductionKind Kind)
+        : Opcode(Opcode), Kind(Kind) {
       assert(Kind != RK_None && "One of the reduction operations is expected.");
     }
 
@@ -6411,16 +6405,14 @@
 
     /// Total number of operands in the reduction operation.
     unsigned getNumberOfOperands() const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       return isMinMax() ? 3 : 2;
     }
 
     /// Checks if the instruction is in basic block \p BB.
     /// For a min/max reduction check that both compare and select are in \p BB.
     bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (IsRedOp && isMinMax()) {
         auto *Cmp = cast<CmpInst>(cast<SelectInst>(I)->getCondition());
         return I->getParent() == BB && Cmp && Cmp->getParent() == BB;
@@ -6430,8 +6422,7 @@
 
     /// Expected number of uses for reduction operations/reduced values.
     bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       // SelectInst must be used twice while the condition op must have single
       // use only.
       if (isMinMax())
@@ -6445,8 +6436,7 @@
 
     /// Initializes the list of reduction operations.
     void initReductionOps(ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (isMinMax())
         ReductionOps.assign(2, ReductionOpsType());
       else
@@ -6455,8 +6445,7 @@
 
     /// Add all reduction operations for the reduction instruction \p I.
     void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) {
-      assert(Kind != RK_None && !!*this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && !!*this && "Expected reduction operation.");
       if (isMinMax()) {
         ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
         ReductionOps[1].emplace_back(I);
@@ -6467,8 +6456,7 @@
 
     /// Checks if instruction is associative and can be vectorized.
     bool isAssociative(Instruction *I) const {
-      assert(Kind != RK_None && *this && LHS && RHS &&
-             "Expected reduction operation.");
+      assert(Kind != RK_None && *this && "Expected reduction operation.");
       switch (Kind) {
       case RK_Arithmetic:
         return I->isAssociative();
@@ -6493,15 +6481,13 @@
     /// Checks if two operation data are both a reduction op or both a reduced
     /// value.
     bool operator==(const OperationData &OD) const {
-      assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) &&
+      assert(((Kind != OD.Kind) || (Opcode != 0 && OD.Opcode != 0)) &&
              "One of the comparing operations is incorrect.");
-      return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode);
+      return Kind == OD.Kind && Opcode == OD.Opcode;
     }
     bool operator!=(const OperationData &OD) const { return !(*this == OD); }
     void clear() {
       Opcode = 0;
-      LHS = nullptr;
-      RHS = nullptr;
       Kind = RK_None;
     }
 
@@ -6513,19 +6499,27 @@
 
     /// Get kind of reduction data.
     ReductionKind getKind() const { return Kind; }
-    Value *getLHS() const { return LHS; }
-    Value *getRHS() const { return RHS; }
-    Type *getConditionType() const {
-      return isMinMax() ? CmpInst::makeCmpResultType(LHS->getType()) : nullptr;
+    /// Left operand of the reduction instruction \p I; null if Kind is RK_None.
+    Value *getLHS(Instruction *I) const {
+      if (Kind == RK_None)
+        return nullptr;
+      return I->getOperand(getFirstOperandIndex());
+    }
+    /// Right operand of the reduction instruction \p I; null if Kind is RK_None.
+    Value *getRHS(Instruction *I) const {
+      if (Kind == RK_None)
+        return nullptr;
+      return I->getOperand(getFirstOperandIndex() + 1);
     }
 
-    /// Creates reduction operation with the current opcode with the IR flags
-    /// from \p ReductionOps.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name,
+    /// Creates reduction operation \p LHS op \p RHS with the current opcode
+    /// and the IR flags from \p ReductionOps.
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name,
                     const ReductionOpsListType &ReductionOps) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
-      auto *Op = createOp(Builder, Name);
+      auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
       case RK_Arithmetic:
         propagateIRFlags(Op, ReductionOps[0]);
@@ -6545,11 +6539,11 @@
     }
-    /// Creates reduction operation with the current opcode with the IR flags
-    /// from \p I.
-    Value *createOp(IRBuilder<> &Builder, const Twine &Name,
-                    Instruction *I) const {
+    /// Creates reduction operation \p LHS op \p RHS with the current opcode
+    /// and the IR flags from \p I.
+    Value *createOp(IRBuilder<> &Builder, Value *LHS, Value *RHS,
+                    const Twine &Name, Instruction *I) const {
       assert(isVectorizable() &&
              "Expected add|fadd or min/max reduction operation.");
-      auto *Op = createOp(Builder, Name);
+      auto *Op = createOp(Builder, LHS, RHS, Name);
       switch (Kind) {
       case RK_Arithmetic:
         propagateIRFlags(Op, I);
@@ -6637,19 +6631,18 @@
     Value *LHS;
     Value *RHS;
     if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) {
-      return OperationData(cast<BinaryOperator>(I)->getOpcode(), LHS, RHS,
-                           RK_Arithmetic);
+      return OperationData(cast<BinaryOperator>(I)->getOpcode(), RK_Arithmetic);
     }
     if (auto *Select = dyn_cast<SelectInst>(I)) {
       // Look for a min/max pattern.
       if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
+        return OperationData(Instruction::ICmp, RK_UMin);
       } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
+        return OperationData(Instruction::ICmp, RK_SMin);
       } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
+        return OperationData(Instruction::ICmp, RK_UMax);
       } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) {
-        return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
+        return OperationData(Instruction::ICmp, RK_SMax);
       } else {
         // Try harder: look for min/max pattern based on instructions producing
         // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2).
@@ -6693,19 +6686,19 @@
 
         case CmpInst::ICMP_ULT:
         case CmpInst::ICMP_ULE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin);
+          return OperationData(Instruction::ICmp, RK_UMin);
 
         case CmpInst::ICMP_SLT:
         case CmpInst::ICMP_SLE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin);
+          return OperationData(Instruction::ICmp, RK_SMin);
 
         case CmpInst::ICMP_UGT:
         case CmpInst::ICMP_UGE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax);
+          return OperationData(Instruction::ICmp, RK_UMax);
 
         case CmpInst::ICMP_SGT:
         case CmpInst::ICMP_SGE:
-          return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax);
+          return OperationData(Instruction::ICmp, RK_SMax);
         }
       }
     }
@@ -6726,13 +6719,13 @@
     //  r *= v1 + v2 + v3 + v4
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
-      if (ReductionData.getLHS() == Phi) {
+      if (ReductionData.getLHS(B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(ReductionData.getRHS());
+        B = dyn_cast<Instruction>(ReductionData.getRHS(B));
         ReductionData = getOperationData(B);
-      } else if (ReductionData.getRHS() == Phi) {
+      } else if (ReductionData.getRHS(B) == Phi) {
         Phi = nullptr;
-        B = dyn_cast<Instruction>(ReductionData.getLHS());
+        B = dyn_cast<Instruction>(ReductionData.getLHS(B));
         ReductionData = getOperationData(B);
       }
     }
@@ -6984,11 +6977,8 @@
       } else {
         // Update the final value in the reduction.
         Builder.SetCurrentDebugLocation(Loc);
-        OperationData VectReductionData(ReductionData.getOpcode(),
-                                        VectorizedTree, ReducedSubTree,
-                                        ReductionData.getKind());
-        VectorizedTree =
-            VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
+        VectorizedTree = ReductionData.createOp(
+            Builder, VectorizedTree, ReducedSubTree, "op.rdx", ReductionOps);
       }
       i += ReduxWidth;
       ReduxWidth = PowerOf2Floor(NumReducedVals - i);
@@ -6999,19 +6989,15 @@
       for (; i < NumReducedVals; ++i) {
         auto *I = cast<Instruction>(ReducedVals[i]);
         Builder.SetCurrentDebugLocation(I->getDebugLoc());
-        OperationData VectReductionData(ReductionData.getOpcode(),
-                                        VectorizedTree, I,
-                                        ReductionData.getKind());
-        VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps);
+        VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, I, "",
+                                                ReductionOps);
       }
       for (auto &Pair : ExternallyUsedValues) {
         // Add each externally used value to the final reduction.
         for (auto *I : Pair.second) {
           Builder.SetCurrentDebugLocation(I->getDebugLoc());
-          OperationData VectReductionData(ReductionData.getOpcode(),
-                                          VectorizedTree, Pair.first,
-                                          ReductionData.getKind());
-          VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
+          VectorizedTree = ReductionData.createOp(Builder, VectorizedTree,
+                                                  Pair.first, "op.extra", I);
         }
       }
 
@@ -7133,9 +7119,8 @@
           Builder.CreateShuffleVector(TmpVec, LeftMask, "rdx.shuf.l");
       Value *RightShuf =
           Builder.CreateShuffleVector(TmpVec, RightMask, "rdx.shuf.r");
-      OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf,
-                                      RightShuf, ReductionData.getKind());
-      TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps);
+      TmpVec = ReductionData.createOp(Builder, LeftShuf, RightShuf, "op.rdx",
+                                      ReductionOps);
     }
 
     // The result is in the first element of the vector.