Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -626,7 +626,8 @@ /// Split: /// (v0, v1, v2, v3) /// ((v0+v2), (v1+v3), undef, undef) - int getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm) const; + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const; /// \returns The cost of Intrinsic instructions. Types analysis only. int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, @@ -825,8 +826,8 @@ ArrayRef Indices, unsigned Alignment, unsigned AddressSpace) = 0; - virtual int getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) = 0; + virtual int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) = 0; virtual int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF) = 0; @@ -1081,9 +1082,9 @@ return Impl.getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, Alignment, AddressSpace); } - int getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) override { - return Impl.getReductionCost(Opcode, Ty, IsPairwiseForm); + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) override { + return Impl.getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); } int getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF) override { Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -388,7 +388,7 @@ return 0; } - unsigned getReductionCost(unsigned, Type *, bool) { return 1; } + unsigned getArithmeticReductionCost(unsigned, Type *, bool) { return 1; } unsigned getCostOfKeepingLiveOverCall(ArrayRef Tys) { return 0; } Index: 
include/llvm/CodeGen/BasicTTIImpl.h =================================================================== --- include/llvm/CodeGen/BasicTTIImpl.h +++ include/llvm/CodeGen/BasicTTIImpl.h @@ -956,46 +956,47 @@ return 0; } - unsigned getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwise) { + /// Try to calculate arithmetic and shuffle op costs for reduction operations. + /// We assume that reduction operations are performed in the following way: + /// 1. Non-pairwise reduction + /// %val1 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <i32 n/2, i32 n/2 + 1, ..., i32 n-1, i32 undef, ..., i32 undef> + /// \----------------v-------------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %red1 = op <n x t> %val, <n x t> %val1 + /// After this operation we have a vector %red1 where only the first n/2 + /// elements are meaningful, the second n/2 elements are undefined and can be + /// dropped. All other operations are actually working with the vector of + /// length n/2, not n, though the real vector length is still n. + /// %val2 = shufflevector<n x t> %red1, <n x t> %undef, + /// <n x i32> <i32 n/4, i32 n/4 + 1, ..., i32 n/2-1, i32 undef, ..., i32 undef> + /// \----------------v-------------/ \----------v------------/ + /// n/4 elements 3*n/4 elements + /// %red2 = op <n x t> %red1, <n x t> %val2 - working with the vector of + /// length n/2, the resulting vector has length n/4 etc. + /// 2. Pairwise reduction: + /// Everything is the same except for an additional shuffle operation which + /// is used to produce operands for pairwise kind of reductions. + /// %val1 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <i32 0, i32 2, ..., i32 n-2, i32 undef, ..., i32 undef> + /// \-------------v----------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %val2 = shufflevector<n x t> %val, <n x t> %undef, + /// <n x i32> <i32 1, i32 3, ..., i32 n-1, i32 undef, ..., i32 undef> + /// \-------------v----------/ \----------v------------/ + /// n/2 elements n/2 elements + /// %red1 = op <n x t> %val1, <n x t> %val2 + /// Again, the operation is performed on an <n x t> vector, but the resulting + /// vector %red1 is effectively an <n/2 x t> vector. + /// + /// The cost model should take into account that the actual length of the + /// vector is reduced on each iteration. 
+ unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwise) { assert(Ty->isVectorTy() && "Expect a vector type"); Type *ScalarTy = Ty->getVectorElementType(); unsigned NumVecElts = Ty->getVectorNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); - // Try to calculate arithmetic and shuffle op costs for reduction operations. - // We're assuming that reduction operation are performing the following way: - // 1. Non-pairwise reduction - // %val1 = shufflevector %val, %undef, - // - // \----------------v-------------/ \----------v------------/ - // n/2 elements n/2 elements - // %red1 = op %val, val1 - // After this operation we have a vector %red1 with only maningfull the - // first n/2 elements, the second n/2 elements are undefined and can be - // dropped. All other operations are actually working with the vector of - // length n/2, not n. though the real vector length is still n. - // %val2 = shufflevector %red1, %undef, - // - // \----------------v-------------/ \----------v------------/ - // n/4 elements 3*n/4 elements - // %red2 = op %red1, val2 - working with the vector of - // length n/2, the resulting vector has length n/4 etc. - // 2. Pairwise reduction: - // Everything is the same except for an additional shuffle operation which - // is used to produce operands for pairwise kind of reductions. - // %val1 = shufflevector %val, %undef, - // - // \-------------v----------/ \----------v------------/ - // n/2 elements n/2 elements - // %val2 = shufflevector %val, %undef, - // - // \-------------v----------/ \----------v------------/ - // n/2 elements n/2 elements - // %red1 = op %val1, val2 - // Again, the operation is performed on vector, but the resulting - // vector %red1 is vector. - // - // The cost model should take into account that the actual length of the - // vector is reduced on each iteration. 
unsigned ArithCost = 0; unsigned ShuffleCost = 0; auto *ConcreteTTI = static_cast(this); Index: include/llvm/Transforms/Vectorize/SLPVectorizer.h =================================================================== --- include/llvm/Transforms/Vectorize/SLPVectorizer.h +++ include/llvm/Transforms/Vectorize/SLPVectorizer.h @@ -83,7 +83,7 @@ bool AllowReorder = false); /// \brief Try to vectorize a chain that may start at the operands of \V; - bool tryToVectorize(BinaryOperator *V, slpvectorizer::BoUpSLP &R); + bool tryToVectorize(Instruction *I, slpvectorizer::BoUpSLP &R); /// \brief Vectorize the store instructions collected in Stores. bool vectorizeStoreChains(slpvectorizer::BoUpSLP &R); Index: lib/Analysis/CostModel.cpp =================================================================== --- lib/Analysis/CostModel.cpp +++ lib/Analysis/CostModel.cpp @@ -24,12 +24,14 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using namespace PatternMatch; #define CM_NAME "cost-model" #define DEBUG_TYPE CM_NAME @@ -183,22 +185,32 @@ return Mask == ActualMask; } -static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, - unsigned Level, unsigned NumLevels) { +static unsigned getReductionOpcode(Value *V, Value *&L, Value *&R) { + L = nullptr; + R = nullptr; + if (m_BinOp(m_Value(L), m_Value(R)).match(V)) + return cast(V)->getOpcode(); + return 0; +} + +static bool matchPairwiseReductionAtLevel(Value *V, unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations. 
// %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 - if (BinOp == nullptr) + if (!V) return false; - assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); + assert(V->getType()->isVectorTy() && "Expecting a vector type"); - unsigned Opcode = BinOp->getOpcode(); - Value *L = BinOp->getOperand(0); - Value *R = BinOp->getOperand(1); + Value *L; + Value *R; + unsigned Opcode = getReductionOpcode(V, L, R); + if (!Opcode) + return false; ShuffleVectorInst *LS = dyn_cast(L); if (!LS && Level) @@ -239,20 +251,16 @@ // Check that the next levels binary operation exists and matches with the // current one. - BinaryOperator *NextLevelBinOp = nullptr; - if (Level + 1 != NumLevels) { - if (!(NextLevelBinOp = dyn_cast(NextLevelOp))) + if (Level + 1 != NumLevels) + if (Opcode != getReductionOpcode(NextLevelOp, L, R)) return false; - else if (NextLevelBinOp->getOpcode() != Opcode) - return false; - } // Shuffle mask for pairwise operation must match. - if (matchPairwiseShuffleMask(LS, true, Level)) { - if (!matchPairwiseShuffleMask(RS, false, Level)) + if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) return false; - } else if (matchPairwiseShuffleMask(RS, true, Level)) { - if (!matchPairwiseShuffleMask(LS, false, Level)) + } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) return false; } else return false; @@ -261,7 +269,7 @@ return true; // Match next level. 
- return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); + return matchPairwiseReductionAtLevel(NextLevelOp, Level, NumLevels); } static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, @@ -277,11 +285,14 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); - if (!RdxStart) + Value *L; + Value *R; + Value *RdxStart = ReduxRoot->getOperand(0); + unsigned RdxOpcode = getReductionOpcode(RdxStart, L, R); + if (RdxOpcode == 0) return false; - Type *VecTy = ReduxRoot->getOperand(0)->getType(); + Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) return false; @@ -307,17 +318,14 @@ if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) return false; - Opcode = RdxStart->getOpcode(); + Opcode = RdxOpcode; Ty = VecTy; return true; } static std::pair -getShuffleAndOtherOprd(BinaryOperator *B) { - - Value *L = B->getOperand(0); - Value *R = B->getOperand(1); +getShuffleAndOtherOprd(Value *L, Value *R) { ShuffleVectorInst *S = nullptr; if ((S = dyn_cast(L))) @@ -340,10 +348,12 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); - if (!RdxStart) + Value *L; + Value *R; + Value *RdxStart = ReduxRoot->getOperand(0); + unsigned RdxOpcode = getReductionOpcode(RdxStart, L, R); + if (RdxOpcode == 0) return false; - unsigned RdxOpcode = RdxStart->getOpcode(); Type *VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); @@ -367,15 +377,13 @@ unsigned NumVecElemsRemain = NumVecElems; while (NumVecElemsRemain - 1) { // Check for the right reduction operation. 
- BinaryOperator *BinOp; - if (!(BinOp = dyn_cast(RdxOp))) - return false; - if (BinOp->getOpcode() != RdxOpcode) + Value *Op = RdxOp; + if (getReductionOpcode(Op, L, R) != RdxOpcode) return false; Value *NextRdxOp; ShuffleVectorInst *Shuffle; - std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); + std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(L, R); // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) @@ -495,10 +503,14 @@ unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getReductionCost(ReduxOpCode, ReduxType, false); - else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getReductionCost(ReduxOpCode, ReduxType, true); + if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false); + } + if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true); + } return TTI->getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(), Idx); Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -412,9 +412,9 @@ return Cost; } -int TargetTransformInfo::getReductionCost(unsigned Opcode, Type *Ty, - bool IsPairwiseForm) const { - int Cost = TTIImpl->getReductionCost(Opcode, Ty, IsPairwiseForm); +int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm) const { + int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } Index: lib/Target/X86/X86TargetTransformInfo.h =================================================================== --- 
lib/Target/X86/X86TargetTransformInfo.h +++ lib/Target/X86/X86TargetTransformInfo.h @@ -78,7 +78,8 @@ int getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Args, FastMathFlags FMF); - int getReductionCost(unsigned Opcode, Type *Ty, bool IsPairwiseForm); + int getArithmeticReductionCost(unsigned Opcode, Type *Ty, + bool IsPairwiseForm); int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, Index: lib/Target/X86/X86TargetTransformInfo.cpp =================================================================== --- lib/Target/X86/X86TargetTransformInfo.cpp +++ lib/Target/X86/X86TargetTransformInfo.cpp @@ -1697,8 +1697,8 @@ return BaseT::getAddressComputationCost(Ty, SE, Ptr); } -int X86TTIImpl::getReductionCost(unsigned Opcode, Type *ValTy, - bool IsPairwise) { +int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy, + bool IsPairwise) { std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy); @@ -1766,7 +1766,7 @@ return LT.first * Entry->Cost; } - return BaseT::getReductionCost(Opcode, ValTy, IsPairwise); + return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise); } /// \brief Calculate the cost of materializing a 64-bit value. 
This helper Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" @@ -45,6 +46,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" @@ -4085,15 +4087,18 @@ return Changed; } -bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) { - if (!V) +bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { + if (!I) return false; - Value *P = V->getParent(); + if (!isa(I) && !isa(I)) + return false; + + Value *P = I->getParent(); // Vectorize in current basic block only. - auto *Op0 = dyn_cast(V->getOperand(0)); - auto *Op1 = dyn_cast(V->getOperand(1)); + auto *Op0 = dyn_cast(I->getOperand(0)); + auto *Op1 = dyn_cast(I->getOperand(1)); if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P) return false; @@ -4187,7 +4192,89 @@ SmallVector ReductionOps; SmallVector ReducedVals; - BinaryOperator *ReductionRoot = nullptr; + /// Contains info about operation, like its opcode, left and right operands + + /// validity. + struct OperationData { + /// true if data is valid and represents one of the instructions. + bool Validity = false; + /// Opcode of the instruction. + unsigned Opcode = 0; + /// Left operand of the instruction. + Value *LHS = nullptr; + /// Right operand of the instruction. + Value *RHS = nullptr; + + public: + explicit OperationData() = default; + /// Construction for reduced values. They are identified by opcode only and + /// don't have associated LHS/RHS values. 
+ explicit OperationData(Value *V) { + if (auto *I = dyn_cast(V)) { + Validity = true; + Opcode = I->getOpcode(); + } + } + /// Constructor for binary reduction operations with opcode and its left and + /// right operands. + OperationData(unsigned Opcode, Value *LHS, Value *RHS) + : Validity(true), Opcode(Opcode), LHS(LHS), RHS(RHS) {} + operator bool() const { return Validity; } + /// Get the index of the first operand of the reduction operation. + unsigned getFirstOperandIndex() const { + assert(Validity && LHS && RHS); + return 0; + } + /// Total number of operands in the reduction operation. + unsigned getNumberOfOperands() const { + assert(Validity && LHS && RHS); + return 2; + } + /// Expected number of uses for reduction operations/reduced values. + unsigned getRequiredNumberOfUses() const { + assert(Validity && LHS && RHS); + return 1; + } + /// Checks if instruction is associative and can be vectorized. + bool isAssociative(Instruction *I) const { + assert(Validity && LHS && RHS); + return I->isAssociative(); + } + /// Checks if the reduction operation can be vectorized. + bool isVectorizable() const { + return Validity && LHS && RHS && + // We currently only support integer and floating-point adds. + (Opcode == Instruction::Add || Opcode == Instruction::FAdd); + } + /// Checks if two operation data represent reduction operation/reduced + /// value of the same kind. + bool operator==(const OperationData &OD) { + return this == &OD || (Validity == OD.Validity && (!LHS == !OD.LHS) && + (!RHS == !OD.RHS) && Opcode == OD.Opcode); + } + bool operator!=(const OperationData &OD) { return !(*this == OD); } + void clear() { + Validity = false; + LHS = nullptr; + RHS = nullptr; + Opcode = 0; + } + unsigned getOpcode() const { + assert(isVectorizable()); + return Opcode; + } + Value *getLHS() const { return LHS; } + Value *getRHS() const { return RHS; } + /// Creates a reduction operation with the current opcode. 
+ Value *createOp(IRBuilder<> &Builder, Value *L, Value *R, + const Twine &Name = "") const { + assert(Opcode == Instruction::FAdd || Opcode == Instruction::Add); + if (Opcode == Instruction::FAdd) + return Builder.CreateFAdd(L, R, Name); + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, L, R, Name); + } + }; + + Instruction *ReductionRoot = nullptr; // After successfull horizontal reduction vectorization attempt for PHI node // vectorizer tries to update root binary op by combining vectorized tree and // the ReductionPHI node. But during vectorization this ReductionPHI can be @@ -4197,64 +4284,74 @@ // is destroyed" crash upon PHI node deletion. WeakVH ReductionPHI; - /// The opcode of the reduction. - Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd; - /// The opcode of the values we perform a reduction on. - unsigned ReducedValueOpcode = 0; + /// The operation data of the reduction operation. + OperationData ReductionData; + /// The operation data of the values we perform a reduction on. + OperationData ReducedValueData; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction = false; + static OperationData getOperationData(Value *V) { + if (!V) + return OperationData(); + + Value *LHS; + Value *RHS; + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) + return {cast(V)->getOpcode(), LHS, RHS}; + return OperationData(V); + } + public: HorizontalReduction() = default; /// \brief Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); + ReductionData = getOperationData(B); + // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. 
if (Phi) { - if (B->getOperand(0) == Phi) { + if (ReductionData.getLHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(1)); - } else if (B->getOperand(1) == Phi) { + B = dyn_cast(ReductionData.getRHS()); + ReductionData = getOperationData(B); + } else if (ReductionData.getRHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(0)); + B = dyn_cast(ReductionData.getLHS()); + ReductionData = getOperationData(B); } } - if (!B) + if (!ReductionData.isVectorizable() || !ReductionData.isAssociative(B)) return false; Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; - ReductionOpcode = B->getOpcode(); - ReducedValueOpcode = 0; + ReducedValueData.clear(); ReductionRoot = B; ReductionPHI = Phi; - // We currently only support adds. - if ((ReductionOpcode != Instruction::Add && - ReductionOpcode != Instruction::FAdd) || - !B->isAssociative()) - return false; - // Post order traverse the reduction tree starting at B. We only handle true - // trees containing only binary operators or selects. + // trees containing only binary operators. SmallVector, 32> Stack; - Stack.push_back(std::make_pair(B, 0)); + Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex())); + const unsigned NUses = ReductionData.getRequiredNumberOfUses(); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVist = Stack.back().second++; - bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode; + OperationData OpData = getOperationData(TreeN); + bool IsReducedValue = OpData != ReductionData; // Postorder vist. - if (EdgeToVist == 2 || IsReducedValue) { + if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) { if (IsReducedValue) ReducedVals.push_back(TreeN); else @@ -4268,34 +4365,35 @@ Value *NextV = TreeN->getOperand(EdgeToVist); if (NextV != Phi) { auto *I = dyn_cast(NextV); + OpData = getOperationData(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a reduced value. 
If the reduced value opcode is not set, // the first met operation != reduction operation is considered as the // reduced value class. - if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || - I->getOpcode() == ReductionOpcode)) { + if (I && (!ReducedValueData || OpData == ReducedValueData || + OpData == ReductionData)) { // Only handle trees in the current basic block. if (I->getParent() != B->getParent()) return false; - // Each tree node needs to have one user except for the ultimate - // reduction. - if (!I->hasOneUse() && I != B) + // Each tree node needs to have minimal number of users except for the + // ultimate reduction. + if (!I->hasNUses(NUses) && I != B) return false; - if (I->getOpcode() == ReductionOpcode) { + if (OpData == ReductionData) { // We need to be able to reassociate the reduction operations. - if (!I->isAssociative()) + if (!OpData.isAssociative(I)) return false; - } else if (ReducedValueOpcode && - ReducedValueOpcode != I->getOpcode()) { + } else if (ReducedValueData && + ReducedValueData != OpData) { // Make sure that the opcodes of the operations that we are going to // reduce match. 
return false; - } else if (!ReducedValueOpcode) - ReducedValueOpcode = I->getOpcode(); + } else if (!ReducedValueData) + ReducedValueData = OpData; - Stack.push_back(std::make_pair(I, 0)); + Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex())); continue; } return false; @@ -4356,8 +4454,8 @@ emitReduction(VectorizedRoot, Builder, ReduxWidth); if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedSubTree, "bin.rdx"); + VectorizedTree = ReductionData.createOp(Builder, VectorizedTree, + ReducedSubTree, "bin.rdx"); } else VectorizedTree = ReducedSubTree; i += ReduxWidth; @@ -4369,8 +4467,8 @@ for (; i < NumReducedVals; ++i) { Builder.SetCurrentDebugLocation( cast(ReducedVals[i])->getDebugLoc()); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedVals[i]); + VectorizedTree = + ReductionData.createOp(Builder, VectorizedTree, ReducedVals[i]); } // Update users. if (ReductionPHI && !isa(ReductionPHI)) { @@ -4394,15 +4492,19 @@ Type *ScalarTy = FirstReducedVal->getType(); Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); - int PairwiseRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, true); - int SplittingRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, false); + int PairwiseRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); + int SplittingRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? 
PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4433,14 +4535,14 @@ Value *RightShuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); - TmpVec = Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, - "bin.rdx"); + TmpVec = + ReductionData.createOp(Builder, LeftShuf, RightShuf, "bin.rdx"); } else { Value *UpperHalf = createRdxShuffleMask(ReduxWidth, i, false, false, Builder); Value *Shuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), UpperHalf, "rdx.shuf"); - TmpVec = Builder.CreateBinOp(ReductionOpcode, TmpVec, Shuf, "bin.rdx"); + TmpVec = ReductionData.createOp(Builder, TmpVec, Shuf, "bin.rdx"); } } @@ -4611,10 +4713,10 @@ /// if it can be done. /// \returns true if a horizontal reduction was matched and reduced. /// \returns false if a horizontal reduction was not matched. 
-static bool canBeVectorized( - PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, - TargetTransformInfo *TTI, - const function_ref Vectorize) { +static bool +canBeVectorized(PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, + TargetTransformInfo *TTI, + const function_ref Vectorize) { if (!ShouldVectorizeHor) return false; @@ -4639,9 +4741,10 @@ } if (Stack.back().isInitial()) { Stack.back().clearInitial(); - if (auto *BI = dyn_cast(Inst)) { + auto *BI = dyn_cast(Inst); + if (BI) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, BI)) { + if (HorRdx.matchAssociativeReduction(P, Inst)) { if (HorRdx.tryToReduce(R, TTI)) { Res = true; P = nullptr; @@ -4659,7 +4762,7 @@ } } P = nullptr; - if (Vectorize(dyn_cast(Inst), R)) { + if (Vectorize(Inst, R)) { Res = true; continue; } @@ -4690,8 +4793,8 @@ P = nullptr; // Try to match and vectorize a horizontal reduction. return canBeVectorized(P, I, BB, R, TTI, - [this](BinaryOperator *BI, BoUpSLP &R) -> bool { - return tryToVectorize(BI, R); + [this](Instruction *I, BoUpSLP &R) -> bool { + return tryToVectorize(I, R); }); } @@ -4800,27 +4903,16 @@ } // Try to vectorize trees that start at compare instructions. - if (CmpInst *CI = dyn_cast(it)) { - if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) { + if (auto *BI = dyn_cast(it)) { + if (!BI->isConditional()) + continue; + + if (vectorizeRootInstruction(nullptr, BI->getCondition(), BB, R, TTI)) { Changed = true; - // We would like to start over since some instructions are deleted - // and the iterator may become invalid value. it = BB->begin(); e = BB->end(); continue; } - - for (int I = 0; I < 2; ++I) { - if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) { - Changed = true; - // We would like to start over since some instructions are deleted - // and the iterator may become invalid value. 
- it = BB->begin(); - e = BB->end(); - break; - } - } - continue; } // Try to vectorize trees that start at insertelement instructions.