Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -626,7 +626,8 @@ /// Split: /// (v0, v1, v2, v3) /// ((v0+v2), (v1+v3), undef, undef) - int getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm) const; + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const; /// \returns The cost of Intrinsic instructions. Types analysis only. int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, @@ -825,8 +826,8 @@ ArrayRef Indices, unsigned Alignment, unsigned AddressSpace) = 0; - virtual int getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) = 0; + virtual int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) = 0; virtual int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF) = 0; @@ -1081,9 +1082,9 @@ return Impl.getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, Alignment, AddressSpace); } - int getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) override { - return Impl.getReductionCost(Opcode, Ty, IsPairwiseForm); + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) override { + return Impl.getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); } int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF) override { Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -388,7 +388,7 @@ return 0; } - unsigned getReductionCost(unsigned, Type *, bool) { return 1; } + unsigned getArithmeticReductionCost(unsigned, Type *, bool) { return 1; } unsigned getCostOfKeepingLiveOverCall(ArrayRef Tys) { return 0; } Index: 
include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -956,46 +956,47 @@ return 0; } - unsigned getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwise) { + /// Try to calculate arithmetic and shuffle op costs for reduction operations. + /// We assume that reduction operations are performed in the following way: + /// 1. Non-pairwise reduction + /// %val1 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <n/2, n/2 + 1, ..., n - 1, undef, ..., undef> + /// \----------------v-------------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %red1 = op <n x t> %val, <n x t> val1 + /// After this operation we have a vector %red1 where only the first n/2 + /// elements are meaningful, the second n/2 elements are undefined and can be + /// dropped. All other operations are actually working with the vector of + /// length n/2, not n, though the real vector length is still n. + /// %val2 = shufflevector<n x t> %red1, <n x t> %undef, + /// <n x i32> <n/4, n/4 + 1, ..., n/2 - 1, undef, ..., undef> + /// \----------------v-------------/ \----------v------------/ + /// n/4 elements 3*n/4 elements + /// %red2 = op <n x t> %red1, <n x t> val2 - working with the vector of + /// length n/2, the resulting vector has length n/4 etc. + /// 2. Pairwise reduction: + /// Everything is the same except for an additional shuffle operation which + /// is used to produce operands for pairwise kind of reductions. + /// %val1 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <0, 2, ..., n - 2, undef, ..., undef> + /// \-------------v----------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %val2 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <1, 3, ..., n - 1, undef, ..., undef> + /// \-------------v----------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %red1 = op <n x t> %val1, <n x t> val2 + /// Again, the operation is performed on <n x t> vector, but the resulting + /// vector %red1 is <n/2 x t> vector. + /// + /// The cost model should take into account that the actual length of the + /// vector is reduced on each iteration. 
+ unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwise) { assert(Ty->isVectorTy() && "Expect a vector type"); Type *ScalarTy = Ty->getVectorElementType(); unsigned NumVecElts = Ty->getVectorNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); - // Try to calculate arithmetic and shuffle op costs for reduction operations. - // We're assuming that reduction operation are performing the following way: - // 1. Non-pairwise reduction - // %val1 = shufflevector %val, %undef, - // - // \----------------v-------------/ \----------v------------/ - // n/2 elements n/2 elements - // %red1 = op %val, val1 - // After this operation we have a vector %red1 with only maningfull the - // first n/2 elements, the second n/2 elements are undefined and can be - // dropped. All other operations are actually working with the vector of - // length n/2, not n. though the real vector length is still n. - // %val2 = shufflevector %red1, %undef, - // - // \----------------v-------------/ \----------v------------/ - // n/4 elements 3*n/4 elements - // %red2 = op %red1, val2 - working with the vector of - // length n/2, the resulting vector has length n/4 etc. - // 2. Pairwise reduction: - // Everything is the same except for an additional shuffle operation which - // is used to produce operands for pairwise kind of reductions. - // %val1 = shufflevector %val, %undef, - // - // \-------------v----------/ \----------v------------/ - // n/2 elements n/2 elements - // %val2 = shufflevector %val, %undef, - // - // \-------------v----------/ \----------v------------/ - // n/2 elements n/2 elements - // %red1 = op %val1, val2 - // Again, the operation is performed on vector, but the resulting - // vector %red1 is vector. - // - // The cost model should take into account that the actual length of the - // vector is reduced on each iteration. 
unsigned ArithCost = 0; unsigned ShuffleCost = 0; auto *ConcreteTTI = static_cast(this); Index: include/llvm/Transforms/Vectorize/SLPVectorizer.h =================================================================== --- include/llvm/Transforms/Vectorize/SLPVectorizer.h +++ include/llvm/Transforms/Vectorize/SLPVectorizer.h @@ -82,8 +82,8 @@ ArrayRef BuildVector = None, bool AllowReorder = false); - /// \brief Try to vectorize a chain that may start at the operands of \V; - bool tryToVectorize(BinaryOperator *V, slpvectorizer::BoUpSLP &R); + /// \brief Try to vectorize a chain that may start at the operands of \p I; + bool tryToVectorize(Instruction *I, slpvectorizer::BoUpSLP &R); /// \brief Vectorize the store instructions collected in Stores. bool vectorizeStoreChains(slpvectorizer::BoUpSLP &R); Index: lib/Analysis/CostModel.cpp =================================================================== --- lib/Analysis/CostModel.cpp +++ lib/Analysis/CostModel.cpp @@ -24,12 +24,14 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using namespace PatternMatch; #define CM_NAME "cost-model" #define DEBUG_TYPE CM_NAME @@ -183,27 +185,47 @@ return Mask == ActualMask; } -static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, - unsigned Level, unsigned NumLevels) { +namespace { +/// Contains opcode + LHS/RHS parts of the reduction operations. 
+struct ReductionData { + explicit ReductionData() = default; + ReductionData(unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + unsigned Opcode = 0; + Value *LHS = nullptr; + Value *RHS = nullptr; +}; +} // namespace + +static ReductionData getReductionData(Instruction *I) { + Value *L = nullptr; + Value *R = nullptr; + if (m_BinOp(m_Value(L), m_Value(R)).match(I)) + return {I->getOpcode(), L, R}; + return ReductionData(); +} + +static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 - if (BinOp == nullptr) + if (!I) return false; - assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); + assert(I->getType()->isVectorTy() && "Expecting a vector type"); - unsigned Opcode = BinOp->getOpcode(); - Value *L = BinOp->getOperand(0); - Value *R = BinOp->getOperand(1); + ReductionData RD = getReductionData(I); + if (!RD.Opcode) + return false; - ShuffleVectorInst *LS = dyn_cast(L); + ShuffleVectorInst *LS = dyn_cast(RD.LHS); if (!LS && Level) return false; - ShuffleVectorInst *RS = dyn_cast(R); + ShuffleVectorInst *RS = dyn_cast(RD.RHS); if (!RS && Level) return false; @@ -228,31 +250,27 @@ // Example: // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R - if (NextLevelOpL && NextLevelOpL != R) + if (NextLevelOpL && NextLevelOpL != RD.RHS) return false; - else if (NextLevelOpR && NextLevelOpR != L) + else if (NextLevelOpR && NextLevelOpR != RD.LHS) return false; - NextLevelOp = NextLevelOpL ? R : L; + NextLevelOp = NextLevelOpL ? RD.RHS : RD.LHS; } else return false; // Check that the next levels binary operation exists and matches with the // current one. 
- BinaryOperator *NextLevelBinOp = nullptr; - if (Level + 1 != NumLevels) { - if (!(NextLevelBinOp = dyn_cast(NextLevelOp))) - return false; - else if (NextLevelBinOp->getOpcode() != Opcode) + if (Level + 1 != NumLevels) + if (RD.Opcode != getReductionData(cast(NextLevelOp)).Opcode) return false; - } // Shuffle mask for pairwise operation must match. - if (matchPairwiseShuffleMask(LS, true, Level)) { - if (!matchPairwiseShuffleMask(RS, false, Level)) + if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) return false; - } else if (matchPairwiseShuffleMask(RS, true, Level)) { - if (!matchPairwiseShuffleMask(LS, false, Level)) + } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) return false; } else return false; @@ -261,7 +279,8 @@ return true; // Match next level. - return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); + return matchPairwiseReductionAtLevel(cast(NextLevelOp), Level, + NumLevels); } static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, @@ -277,11 +296,14 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) return false; + ReductionData RD = getReductionData(RdxStart); + if (RD.Opcode == 0) + return false; - Type *VecTy = ReduxRoot->getOperand(0)->getType(); + Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) return false; @@ -307,17 +329,14 @@ if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) return false; - Opcode = RdxStart->getOpcode(); + Opcode = RD.Opcode; Ty = VecTy; return true; } static std::pair -getShuffleAndOtherOprd(BinaryOperator *B) { - - Value *L = B->getOperand(0); - Value *R = B->getOperand(1); +getShuffleAndOtherOprd(Value *L, Value *R) { 
ShuffleVectorInst *S = nullptr; if ((S = dyn_cast(L))) @@ -340,10 +359,12 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) return false; - unsigned RdxOpcode = RdxStart->getOpcode(); + ReductionData RD = getReductionData(RdxStart); + if (RD.Opcode == 0) + return false; Type *VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); @@ -367,15 +388,17 @@ unsigned NumVecElemsRemain = NumVecElems; while (NumVecElemsRemain - 1) { // Check for the right reduction operation. - BinaryOperator *BinOp; - if (!(BinOp = dyn_cast(RdxOp))) + Instruction *Op = dyn_cast(RdxOp); + if (!Op) return false; - if (BinOp->getOpcode() != RdxOpcode) + ReductionData RDLevel = getReductionData(Op); + if (RDLevel.Opcode != RD.Opcode) return false; Value *NextRdxOp; ShuffleVectorInst *Shuffle; - std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); + std::tie(NextRdxOp, Shuffle) = + getShuffleAndOtherOprd(RDLevel.LHS, RDLevel.RHS); // Check the current reduction operation and the shuffle use the same value. 
if (Shuffle == nullptr) @@ -398,7 +421,7 @@ MaskStart *= 2; } - Opcode = RdxOpcode; + Opcode = RD.Opcode; Ty = VecTy; return true; } @@ -495,10 +518,14 @@ unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getReductionCost(ReduxOpCode, ReduxType, false); - else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getReductionCost(ReduxOpCode, ReduxType, true); + if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false); + } + if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true); + } return TTI->getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(), Idx); Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -412,9 +412,9 @@ return Cost; } -int TargetTransformInfo::getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) const { - int Cost = TTIImpl->getReductionCost(Opcode, Ty, IsPairwiseForm); +int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const { + int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } Index: lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- lib/Target/X86/X86TargetTransformInfo.h +++ lib/Target/X86/X86TargetTransformInfo.h @@ -78,7 +78,8 @@ int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Args, FastMathFlags FMF); - int getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm); int 
getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, Index: lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- lib/Target/X86/X86TargetTransformInfo.cpp +++ lib/Target/X86/X86TargetTransformInfo.cpp @@ -1710,8 +1710,8 @@ return BaseT::getAddressComputationCost(Ty, SE, Ptr); } -int X86TTIImpl::getReductionCost(unsigned Opcode, Type *ValTy, - bool IsPairwise) { +int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy, + bool IsPairwise) { std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy); @@ -1779,7 +1779,7 @@ return LT.first * Entry->Cost; } - return BaseT::getReductionCost(Opcode, ValTy, IsPairwise); + return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise); } /// \brief Calculate the cost of materializing a 64-bit value. This helper Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" @@ -45,6 +46,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" @@ -4149,15 +4151,18 @@ return Changed; } -bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) { - if (!V) +bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { + if (!I) return false; - Value *P = V->getParent(); + if (!isa(I)) + return false; + + Value *P = I->getParent(); // Vectorize in current basic block only. 
- auto *Op0 = dyn_cast(V->getOperand(0)); - auto *Op1 = dyn_cast(V->getOperand(1)); + auto *Op0 = dyn_cast(I->getOperand(0)); + auto *Op1 = dyn_cast(I->getOperand(1)); if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P) return false; @@ -4253,7 +4258,88 @@ // Use map vector to make stable output. MapVector ExtraArgs; - BinaryOperator *ReductionRoot = nullptr; + /// Contains info about operation, like its opcode, left and right operands + + /// validity. + struct OperationData { + /// true if the operation is a reduced value, false if reduction operation. + bool IsReducedValue = false; + /// Opcode of the instruction. + unsigned Opcode = 0; + /// Left operand of the reduction operation. + Value *LHS = nullptr; + /// Right operand of the reduction operation. + Value *RHS = nullptr; + + public: + explicit OperationData() = default; + /// Constructor for reduced values. They are identified by opcode only and + /// don't have associated LHS/RHS values. + explicit OperationData(Value *V) : IsReducedValue(true) { + if (auto *I = dyn_cast(V)) + Opcode = I->getOpcode(); + } + /// Constructor for binary reduction operations with opcode and its left and + /// right operands. + OperationData(BinaryOperator::BinaryOps Opcode, Value *LHS, Value *RHS) + : IsReducedValue(false), Opcode(Opcode), LHS(LHS), RHS(RHS) {} + operator bool() const { return Opcode; } + /// Get the index of the first operand. + unsigned getFirstOperandIndex() const { + assert(Opcode); + return 0; + } + /// Total number of operands in the reduction operation. + unsigned getNumberOfOperands() const { + assert(!IsReducedValue && Opcode && LHS && RHS); + return 2; + } + /// Expected number of uses for reduction operations/reduced values. + unsigned getRequiredNumberOfUses() const { + assert(!IsReducedValue && Opcode && LHS && RHS); + return 1; + } + /// Checks if instruction is associative and can be vectorized. 
+ bool isAssociative(Instruction *I) const { + assert(!IsReducedValue && Opcode && LHS && RHS); + return I->isAssociative(); + } + /// Checks if the reduction operation can be vectorized. + bool isVectorizable() const { + return LHS && RHS && + // We currently only support adds. + (Opcode == Instruction::Add || Opcode == Instruction::FAdd); + } + /// Checks if two operation data are both a reduction op or both a reduced + /// value. + bool operator==(const OperationData &OD) { + return this == &OD || + (IsReducedValue == OD.IsReducedValue && (!LHS == !OD.LHS) && + (!RHS == !OD.RHS) && Opcode == OD.Opcode); + } + bool operator!=(const OperationData &OD) { return !(*this == OD); } + void clear() { + IsReducedValue = false; + Opcode = 0; + LHS = nullptr; + RHS = nullptr; + } + /// Get opcode of the reduction operation. + unsigned getOpcode() const { + assert(isVectorizable()); + return Opcode; + } + Value *getLHS() const { return LHS; } + Value *getRHS() const { return RHS; } + /// Creates a reduction operation with the current opcode. + Value *createOp(IRBuilder<> &Builder, Value *L, Value *R, + const Twine &Name = "") const { + assert(!IsReducedValue && + (Opcode == Instruction::FAdd || Opcode == Instruction::Add)); + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, L, R, Name); + } + }; + + Instruction *ReductionRoot = nullptr; // After successfull horizontal reduction vectorization attempt for PHI node // vectorizer tries to update root binary op by combining vectorized tree and // the ReductionPHI node. But during vectorization this ReductionPHI can be @@ -4263,10 +4349,10 @@ // is destroyed" crash upon PHI node deletion. WeakVH ReductionPHI; - /// The opcode of the reduction. - Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd; - /// The opcode of the values we perform a reduction on. - unsigned ReducedValueOpcode = 0; + /// The operation data of the reduction operation. 
+ OperationData ReductionData; + /// The operation data of the values we perform a reduction on. + OperationData ReducedValueData; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction = false; @@ -4291,56 +4377,66 @@ } } + static OperationData getOperationData(Value *V) { + if (!V) + return OperationData(); + + Value *LHS; + Value *RHS; + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) + return {cast(V)->getOpcode(), LHS, RHS}; + return OperationData(V); + } + public: HorizontalReduction() = default; /// \brief Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); + ReductionData = getOperationData(B); + // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. if (Phi) { - if (B->getOperand(0) == Phi) { + if (ReductionData.getLHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(1)); - } else if (B->getOperand(1) == Phi) { + B = dyn_cast(ReductionData.getRHS()); + ReductionData = getOperationData(B); + } else if (ReductionData.getRHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(0)); + B = dyn_cast(ReductionData.getLHS()); + ReductionData = getOperationData(B); } } - if (!B) + if (!ReductionData.isVectorizable() || !ReductionData.isAssociative(B)) return false; Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; - ReductionOpcode = B->getOpcode(); - ReducedValueOpcode = 0; + ReducedValueData.clear(); ReductionRoot = B; ReductionPHI = Phi; - // We currently only support adds. 
- if ((ReductionOpcode != Instruction::Add && - ReductionOpcode != Instruction::FAdd) || - !B->isAssociative()) - return false; - // Post order traverse the reduction tree starting at B. We only handle true - // trees containing only binary operators or selects. + // trees containing only binary operators. SmallVector, 32> Stack; - Stack.push_back(std::make_pair(B, 0)); + Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex())); + const unsigned NUses = ReductionData.getRequiredNumberOfUses(); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVist = Stack.back().second++; - bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode; + OperationData OpData = getOperationData(TreeN); + bool IsReducedValue = OpData != ReductionData; // Postorder vist. - if (EdgeToVist == 2 || IsReducedValue) { + if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -4369,12 +4465,13 @@ Value *NextV = TreeN->getOperand(EdgeToVist); if (NextV != Phi) { auto *I = dyn_cast(NextV); + OpData = getOperationData(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a reduced value. If the reduced value opcode is not set, // the first met operation != reduction operation is considered as the // reduced value class. - if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || - I->getOpcode() == ReductionOpcode)) { + if (I && (!ReducedValueData || OpData == ReducedValueData || + OpData == ReductionData)) { // Only handle trees in the current basic block. if (I->getParent() != B->getParent()) { // I is an extra argument for TreeN (its parent operation). @@ -4382,32 +4479,32 @@ continue; } - // Each tree node needs to have one user except for the ultimate - // reduction. - if (!I->hasOneUse() && I != B) { + // Each tree node needs to have minimal number of users except for the + // ultimate reduction. 
+ if (!I->hasNUses(NUses) && I != B) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - if (I->getOpcode() == ReductionOpcode) { + if (OpData == ReductionData) { // We need to be able to reassociate the reduction operations. - if (!I->isAssociative()) { + if (!OpData.isAssociative(I)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - } else if (ReducedValueOpcode && - ReducedValueOpcode != I->getOpcode()) { + } else if (ReducedValueData && + ReducedValueData != OpData) { // Make sure that the opcodes of the operations that we are going to // reduce match. // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; - } else if (!ReducedValueOpcode) - ReducedValueOpcode = I->getOpcode(); + } else if (!ReducedValueData) + ReducedValueData = OpData; - Stack.push_back(std::make_pair(I, 0)); + Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex())); continue; } // NextV is an extra argument for TreeN (its parent operation). 
@@ -4472,8 +4569,8 @@ emitReduction(VectorizedRoot, Builder, ReduxWidth); if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedSubTree, "bin.rdx"); + VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, + ReducedSubTree, "bin.rdx"); } else VectorizedTree = ReducedSubTree; i += ReduxWidth; @@ -4486,12 +4583,12 @@ auto *I = cast(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); VectorizedTree = - Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I); + ReductionData.createOp(Builder, VectorizedTree, I); } for (auto &Pair : ExternallyUsedValues) { Builder.SetCurrentDebugLocation(Pair.second); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - Pair.first, "bin.extra"); + VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, + Pair.first, "bin.extra"); } // Update users. if (ReductionPHI && !isa(ReductionPHI)) { @@ -4515,15 +4612,19 @@ Type *ScalarTy = FirstReducedVal->getType(); Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); - int PairwiseRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, true); - int SplittingRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, false); + int PairwiseRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); + int SplittingRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? 
PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4554,14 +4655,14 @@ Value *RightShuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); - TmpVec = Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, - "bin.rdx"); + TmpVec = + ReductionData.createOp(Builder, LeftShuf, RightShuf, "bin.rdx"); } else { Value *UpperHalf = createRdxShuffleMask(ReduxWidth, i, false, false, Builder); Value *Shuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), UpperHalf, "rdx.shuf"); - TmpVec = Builder.CreateBinOp(ReductionOpcode, TmpVec, Shuf, "bin.rdx"); + TmpVec = ReductionData.createOp(Builder, TmpVec, Shuf, "bin.rdx"); } } @@ -4760,9 +4861,10 @@ } if (Stack.back().isInitial()) { Stack.back().clearInitial(); - if (auto *BI = dyn_cast(Inst)) { + auto *BI = dyn_cast(Inst); + if (BI) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, BI)) { + if (HorRdx.matchAssociativeReduction(P, Inst)) { if (HorRdx.tryToReduce(R, TTI)) { Res = true; P = nullptr;