Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6312,20 +6312,13 @@ /// Opcode of the instruction. unsigned Opcode = 0; - /// Left operand of the reduction operation. - Value *LHS = nullptr; - - /// Right operand of the reduction operation. - Value *RHS = nullptr; - /// Kind of the reduction operation. ReductionKind Kind = RK_None; /// Checks if the reduction operation can be vectorized. bool isVectorizable() const { - return LHS && RHS && - // We currently only support add/mul/logical && min/max reductions. - ((Kind == RK_Arithmetic && + // We currently only support add/mul/logical && min/max reductions. + return ((Kind == RK_Arithmetic && (Opcode == Instruction::Add || Opcode == Instruction::FAdd || Opcode == Instruction::Mul || Opcode == Instruction::FMul || Opcode == Instruction::And || Opcode == Instruction::Or || @@ -6378,8 +6371,8 @@ /// Constructor for reduction operations with opcode and its left and /// right operands. - OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind) - : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { + OperationData(unsigned Opcode, ReductionKind Kind) + : Opcode(Opcode), Kind(Kind) { assert(Kind != RK_None && "One of the reduction operations is expected."); } @@ -6412,16 +6405,14 @@ /// Total number of operands in the reduction operation. unsigned getNumberOfOperands() const { - assert(Kind != RK_None && !!*this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && !!*this && "Expected reduction operation."); return isMinMax() ? 3 : 2; } /// Checks if the instruction is in basic block \p BB. /// For a min/max reduction check that both compare and select are in \p BB. bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) const { - assert(Kind != RK_None && !!*this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && !!*this && "Expected reduction operation."); if (IsRedOp && isMinMax()) { auto *Cmp = cast(cast(I)->getCondition()); return I->getParent() == BB && Cmp && Cmp->getParent() == BB; @@ -6431,8 +6422,7 @@ /// Expected number of uses for reduction operations/reduced values. bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const { - assert(Kind != RK_None && !!*this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && !!*this && "Expected reduction operation."); // SelectInst must be used twice while the condition op must have single // use only. if (isMinMax()) @@ -6446,8 +6436,7 @@ /// Initializes the list of reduction operations. void initReductionOps(ReductionOpsListType &ReductionOps) { - assert(Kind != RK_None && !!*this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && !!*this && "Expected reduction operation."); if (isMinMax()) ReductionOps.assign(2, ReductionOpsType()); else @@ -6456,8 +6445,7 @@ /// Add all reduction operations for the reduction instruction \p I. void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) { - assert(Kind != RK_None && !!*this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && !!*this && "Expected reduction operation."); if (isMinMax()) { ReductionOps[0].emplace_back(cast(I)->getCondition()); ReductionOps[1].emplace_back(I); @@ -6468,8 +6456,7 @@ /// Checks if instruction is associative and can be vectorized. bool isAssociative(Instruction *I) const { - assert(Kind != RK_None && *this && LHS && RHS && - "Expected reduction operation."); + assert(Kind != RK_None && *this && "Expected reduction operation."); switch (Kind) { case RK_Arithmetic: return I->isAssociative(); @@ -6494,15 +6481,13 @@ /// Checks if two operation data are both a reduction op or both a reduced /// value. bool operator==(const OperationData &OD) const { - assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && + assert(((Kind != OD.Kind) || (Opcode != 0 && OD.Opcode != 0)) && "One of the comparing operations is incorrect."); - return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode); + return Kind == OD.Kind && Opcode == OD.Opcode; } bool operator!=(const OperationData &OD) const { return !(*this == OD); } void clear() { Opcode = 0; - LHS = nullptr; - RHS = nullptr; Kind = RK_None; } @@ -6514,10 +6499,15 @@ /// Get kind of reduction data. ReductionKind getKind() const { return Kind; } - Value *getLHS() const { return LHS; } - Value *getRHS() const { return RHS; } - Type *getConditionType() const { - return isMinMax() ? CmpInst::makeCmpResultType(LHS->getType()) : nullptr; + Value *getLHS(Instruction *I) const { + if (Kind == RK_None) + return nullptr; + return I->getOperand(getFirstOperandIndex()); + } + Value *getRHS(Instruction *I) const { + if (Kind == RK_None) + return nullptr; + return I->getOperand(getFirstOperandIndex() + 1); } /// Creates reduction operation with the current opcode with the IR flags @@ -6639,19 +6629,18 @@ Value *LHS; Value *RHS; if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(I)) { - return OperationData(cast(I)->getOpcode(), LHS, RHS, - RK_Arithmetic); + return OperationData(cast(I)->getOpcode(), RK_Arithmetic); } if (auto *Select = dyn_cast(I)) { // Look for a min/max pattern. if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); + return OperationData(Instruction::ICmp, RK_UMin); } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin); + return OperationData(Instruction::ICmp, RK_SMin); } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); + return OperationData(Instruction::ICmp, RK_UMax); } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) { - return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax); + return OperationData(Instruction::ICmp, RK_SMax); } else { // Try harder: look for min/max pattern based on instructions producing // same values such as: select ((cmp Inst1, Inst2), Inst1, Inst2). @@ -6695,19 +6684,19 @@ case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); + return OperationData(Instruction::ICmp, RK_UMin); case CmpInst::ICMP_SLT: case CmpInst::ICMP_SLE: - return OperationData(Instruction::ICmp, LHS, RHS, RK_SMin); + return OperationData(Instruction::ICmp, RK_SMin); case CmpInst::ICMP_UGT: case CmpInst::ICMP_UGE: - return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); + return OperationData(Instruction::ICmp, RK_UMax); case CmpInst::ICMP_SGT: case CmpInst::ICMP_SGE: - return OperationData(Instruction::ICmp, LHS, RHS, RK_SMax); + return OperationData(Instruction::ICmp, RK_SMax); } } } @@ -6728,13 +6717,13 @@ // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. if (Phi) { - if (ReductionData.getLHS() == Phi) { + if (ReductionData.getLHS(B) == Phi) { Phi = nullptr; - B = dyn_cast(ReductionData.getRHS()); + B = dyn_cast(ReductionData.getRHS(B)); ReductionData = getOperationData(B); - } else if (ReductionData.getRHS() == Phi) { + } else if (ReductionData.getRHS(B) == Phi) { Phi = nullptr; - B = dyn_cast(ReductionData.getLHS()); + B = dyn_cast(ReductionData.getLHS(B)); ReductionData = getOperationData(B); } }