Index: llvm/trunk/lib/Analysis/CostModel.cpp =================================================================== --- llvm/trunk/lib/Analysis/CostModel.cpp +++ llvm/trunk/lib/Analysis/CostModel.cpp @@ -24,12 +24,14 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; +using namespace PatternMatch; #define CM_NAME "cost-model" #define DEBUG_TYPE CM_NAME @@ -183,27 +185,46 @@ return Mask == ActualMask; } -static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, - unsigned Level, unsigned NumLevels) { +namespace { +/// Contains opcode + LHS/RHS parts of the reduction operations. +struct ReductionData { + explicit ReductionData() = default; + ReductionData(unsigned Opcode, Value *LHS, Value *RHS) + : Opcode(Opcode), LHS(LHS), RHS(RHS) {} + unsigned Opcode = 0; + Value *LHS = nullptr; + Value *RHS = nullptr; +}; +} // namespace + +static Optional getReductionData(Instruction *I) { + Value *L, *R; + if (m_BinOp(m_Value(L), m_Value(R)).match(I)) + return ReductionData(I->getOpcode(), L, R); + return llvm::None; +} + +static bool matchPairwiseReductionAtLevel(Instruction *I, unsigned Level, + unsigned NumLevels) { // Match one level of pairwise operations.
// %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, // <4 x i32> // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 - if (BinOp == nullptr) + if (!I) return false; - assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); + assert(I->getType()->isVectorTy() && "Expecting a vector type"); - unsigned Opcode = BinOp->getOpcode(); - Value *L = BinOp->getOperand(0); - Value *R = BinOp->getOperand(1); + Optional RD = getReductionData(I); + if (!RD) + return false; - ShuffleVectorInst *LS = dyn_cast(L); + ShuffleVectorInst *LS = dyn_cast(RD->LHS); if (!LS && Level) return false; - ShuffleVectorInst *RS = dyn_cast(R); + ShuffleVectorInst *RS = dyn_cast(RD->RHS); if (!RS && Level) return false; @@ -228,31 +249,33 @@ // Example: // %NextLevelOpL = shufflevector %R, <1, undef ...> // %BinOp = fadd %NextLevelOpL, %R - if (NextLevelOpL && NextLevelOpL != R) + if (NextLevelOpL && NextLevelOpL != RD->RHS) return false; - else if (NextLevelOpR && NextLevelOpR != L) + else if (NextLevelOpR && NextLevelOpR != RD->LHS) return false; - NextLevelOp = NextLevelOpL ? R : L; + NextLevelOp = NextLevelOpL ? RD->RHS : RD->LHS; } else return false; // Check that the next levels binary operation exists and matches with the // current one. - BinaryOperator *NextLevelBinOp = nullptr; if (Level + 1 != NumLevels) { - if (!(NextLevelBinOp = dyn_cast(NextLevelOp))) - return false; - else if (NextLevelBinOp->getOpcode() != Opcode) + // NextLevelOp need not be an instruction (e.g. a function argument). + if (!isa<Instruction>(NextLevelOp)) + return false; + Optional NextLevelRD = + getReductionData(cast(NextLevelOp)); + if (!NextLevelRD || RD->Opcode != NextLevelRD->Opcode) return false; } // Shuffle mask for pairwise operation must match.
- if (matchPairwiseShuffleMask(LS, true, Level)) { - if (!matchPairwiseShuffleMask(RS, false, Level)) + if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) return false; - } else if (matchPairwiseShuffleMask(RS, true, Level)) { - if (!matchPairwiseShuffleMask(LS, false, Level)) + } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { + if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) return false; } else return false; @@ -261,7 +284,8 @@ return true; // Match next level. - return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); + return matchPairwiseReductionAtLevel(cast(NextLevelOp), Level, + NumLevels); } static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, @@ -277,11 +301,14 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) return false; + Optional RD = getReductionData(RdxStart); + if (!RD) + return false; - Type *VecTy = ReduxRoot->getOperand(0)->getType(); + Type *VecTy = RdxStart->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); if (!isPowerOf2_32(NumVecElems)) return false; @@ -307,17 +334,14 @@ if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) return false; - Opcode = RdxStart->getOpcode(); + Opcode = RD->Opcode; Ty = VecTy; return true; } static std::pair -getShuffleAndOtherOprd(BinaryOperator *B) { - - Value *L = B->getOperand(0); - Value *R = B->getOperand(1); +getShuffleAndOtherOprd(Value *L, Value *R) { ShuffleVectorInst *S = nullptr; if ((S = dyn_cast(L))) @@ -340,10 +364,12 @@ if (Idx != 0) return false; - BinaryOperator *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); + auto *RdxStart = dyn_cast(ReduxRoot->getOperand(0)); if (!RdxStart) return false; - unsigned RdxOpcode = RdxStart->getOpcode(); + Optional RD = getReductionData(RdxStart); + if (!RD) + return false; Type
*VecTy = ReduxRoot->getOperand(0)->getType(); unsigned NumVecElems = VecTy->getVectorNumElements(); @@ -362,20 +388,21 @@ // %r = extractelement <4 x float> %bin.rdx8, i32 0 unsigned MaskStart = 1; - Value *RdxOp = RdxStart; + Instruction *RdxOp = RdxStart; SmallVector ShuffleMask(NumVecElems, 0); unsigned NumVecElemsRemain = NumVecElems; while (NumVecElemsRemain - 1) { // Check for the right reduction operation. - BinaryOperator *BinOp; - if (!(BinOp = dyn_cast(RdxOp))) + if (!RdxOp) return false; - if (BinOp->getOpcode() != RdxOpcode) + Optional RDLevel = getReductionData(RdxOp); + if (!RDLevel || RDLevel->Opcode != RD->Opcode) return false; Value *NextRdxOp; ShuffleVectorInst *Shuffle; - std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); + std::tie(NextRdxOp, Shuffle) = + getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS); // Check the current reduction operation and the shuffle use the same value. if (Shuffle == nullptr) @@ -393,12 +420,12 @@ if (ShuffleMask != Mask) return false; - RdxOp = NextRdxOp; + RdxOp = dyn_cast(NextRdxOp); NumVecElemsRemain /= 2; MaskStart *= 2; } - Opcode = RdxOpcode; + Opcode = RD->Opcode; Ty = VecTy; return true; } @@ -495,10 +522,14 @@ unsigned ReduxOpCode; Type *ReduxType; - if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, false); - else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) - return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, true); + if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/false); + } + if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { + return TTI->getArithmeticReductionCost(ReduxOpCode, ReduxType, + /*IsPairwiseForm=*/true); + } return TTI->getVectorInstrCost(I->getOpcode(), EEI->getOperand(0)->getType(), Idx); Index: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
=================================================================== --- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" @@ -48,6 +49,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" @@ -4321,12 +4323,104 @@ // Use map vector to make stable output. MapVector ExtraArgs; - BinaryOperator *ReductionRoot = nullptr; + /// Contains info about operation, like its opcode, left and right operands. + struct OperationData { + /// true if the operation is a reduced value, false if reduction operation. + bool IsReducedValue = false; + /// Opcode of the instruction. + unsigned Opcode = 0; + /// Left operand of the reduction operation. + Value *LHS = nullptr; + /// Right operand of the reduction operation. + Value *RHS = nullptr; + + /// Checks if the reduction operation can be vectorized. + bool isVectorizable() const { + return LHS && RHS && + // We currently only support adds. + (Opcode == Instruction::Add || Opcode == Instruction::FAdd); + } + + public: + explicit OperationData() = default; + /// Constructor for reduced values. They are identified by opcode only and + /// don't have associated LHS/RHS values. + explicit OperationData(Value *V) : IsReducedValue(true) { + if (auto *I = dyn_cast(V)) + Opcode = I->getOpcode(); + } + /// Constructor for binary reduction operations with opcode and its left and + /// right operands. + OperationData(unsigned Opcode, Value *LHS, Value *RHS) + : IsReducedValue(false), Opcode(Opcode), LHS(LHS), RHS(RHS) {} + explicit operator bool() const { return Opcode; } + /// Get the index of the first operand.
+ unsigned getFirstOperandIndex() const { + assert(!!*this && "The opcode is not set."); + return 0; + } + /// Total number of operands in the reduction operation. + unsigned getNumberOfOperands() const { + assert(!IsReducedValue && !!*this && LHS && RHS && + "Expected reduction operation."); + return 2; + } + /// Expected number of uses for reduction operations/reduced values. + unsigned getRequiredNumberOfUses() const { + assert(!IsReducedValue && !!*this && LHS && RHS && + "Expected reduction operation."); + return 1; + } + /// Checks if instruction is associative and can be vectorized. + bool isAssociative(Instruction *I) const { + assert(!IsReducedValue && *this && LHS && RHS && + "Expected reduction operation."); + return I->isAssociative(); + } + /// Checks if the reduction operation can be vectorized. + bool isVectorizable(Instruction *I) const { + return isVectorizable() && isAssociative(I); + } + + /// Checks if two operation data are both a reduction op or both a reduced + /// value. + bool operator==(const OperationData &OD) const { + assert(((IsReducedValue != OD.IsReducedValue) || + ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && + "One of the comparing operations is incorrect."); + return this == &OD || + (IsReducedValue == OD.IsReducedValue && Opcode == OD.Opcode); + } + bool operator!=(const OperationData &OD) const { return !(*this == OD); } + void clear() { + IsReducedValue = false; + Opcode = 0; + LHS = nullptr; + RHS = nullptr; + } + /// Get the opcode of the reduction operation. + unsigned getOpcode() const { + assert(isVectorizable() && "Expected vectorizable operation."); + return Opcode; + } + Value *getLHS() const { return LHS; } + Value *getRHS() const { return RHS; } + /// Creates reduction operation with the current opcode.
+ Value *createOp(IRBuilder<> &Builder, const Twine &Name = "") const { + assert(!IsReducedValue && + (Opcode == Instruction::FAdd || Opcode == Instruction::Add) && + "Expected add|fadd reduction operation."); + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, + Name); + } + }; + + Instruction *ReductionRoot = nullptr; - /// The opcode of the reduction. - Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd; - /// The opcode of the values we perform a reduction on. - unsigned ReducedValueOpcode = 0; + /// The operation data of the reduction operation. + OperationData ReductionData; + /// The operation data of the values we perform a reduction on. + OperationData ReducedValueData; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction = false; @@ -4351,55 +4445,65 @@ } } + static OperationData getOperationData(Value *V) { + if (!V) + return OperationData(); + + Value *LHS; + Value *RHS; + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) + return OperationData(cast(V)->getOpcode(), LHS, RHS); + return OperationData(V); + } + public: HorizontalReduction() = default; /// \brief Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); + ReductionData = getOperationData(B); + // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'.
if (Phi) { - if (B->getOperand(0) == Phi) { + if (ReductionData.getLHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(1)); - } else if (B->getOperand(1) == Phi) { + B = dyn_cast(ReductionData.getRHS()); + ReductionData = getOperationData(B); + } else if (ReductionData.getRHS() == Phi) { Phi = nullptr; - B = dyn_cast(B->getOperand(0)); + B = dyn_cast(ReductionData.getLHS()); + ReductionData = getOperationData(B); } } - if (!B) + if (!ReductionData.isVectorizable(B)) return false; Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; - ReductionOpcode = B->getOpcode(); - ReducedValueOpcode = 0; + ReducedValueData.clear(); ReductionRoot = B; - // We currently only support adds. - if ((ReductionOpcode != Instruction::Add && - ReductionOpcode != Instruction::FAdd) || - !B->isAssociative()) - return false; - // Post order traverse the reduction tree starting at B. We only handle true - // trees containing only binary operators or selects. + // trees containing only binary operators. SmallVector, 32> Stack; - Stack.push_back(std::make_pair(B, 0)); + Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex())); + const unsigned NUses = ReductionData.getRequiredNumberOfUses(); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVist = Stack.back().second++; - bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode; + OperationData OpData = getOperationData(TreeN); + bool IsReducedValue = OpData != ReductionData; // Postorder vist. - if (EdgeToVist == 2 || IsReducedValue) { + if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -4428,12 +4532,13 @@ Value *NextV = TreeN->getOperand(EdgeToVist); if (NextV != Phi) { auto *I = dyn_cast(NextV); + OpData = getOperationData(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a reduced value.
If the reduced value opcode is not set, // the first met operation != reduction operation is considered as the // reduced value class. - if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || - I->getOpcode() == ReductionOpcode)) { + if (I && (!ReducedValueData || OpData == ReducedValueData || + OpData == ReductionData)) { // Only handle trees in the current basic block. if (I->getParent() != B->getParent()) { // I is an extra argument for TreeN (its parent operation). @@ -4441,32 +4546,32 @@ continue; } - // Each tree node needs to have one user except for the ultimate - // reduction. - if (!I->hasOneUse() && I != B) { + // Each tree node needs to have minimal number of users except for the + // ultimate reduction. + if (!I->hasNUses(NUses) && I != B) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - if (I->getOpcode() == ReductionOpcode) { + if (OpData == ReductionData) { // We need to be able to reassociate the reduction operations. - if (!I->isAssociative()) { + if (!OpData.isAssociative(I)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - } else if (ReducedValueOpcode && - ReducedValueOpcode != I->getOpcode()) { + } else if (ReducedValueData && + ReducedValueData != OpData) { // Make sure that the opcodes of the operations that we are going to // reduce match. // I is an extra argument for TreeN (its parent operation).
markExtraArg(Stack.back(), I); continue; - } else if (!ReducedValueOpcode) - ReducedValueOpcode = I->getOpcode(); + } else if (!ReducedValueData) + ReducedValueData = OpData; - Stack.push_back(std::make_pair(I, 0)); + Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex())); continue; } } @@ -4539,8 +4644,9 @@ emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI); if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedSubTree, "bin.rdx"); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, ReducedSubTree); + VectorizedTree = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(VectorizedTree, ReductionOps); } else VectorizedTree = ReducedSubTree; @@ -4553,8 +4659,9 @@ for (; i < NumReducedVals; ++i) { auto *I = cast(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = - Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, I); + VectorizedTree = VectReductionData.createOp(Builder); propagateIRFlags(VectorizedTree, ReductionOps); } for (auto &Pair : ExternallyUsedValues) { @@ -4563,8 +4670,9 @@ // Add each externally used value to the final reduction.
for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - Pair.first, "bin.extra"); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, Pair.first); + VectorizedTree = VectReductionData.createOp(Builder, "bin.extra"); propagateIRFlags(VectorizedTree, I); } } @@ -4586,16 +4694,18 @@ Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); int PairwiseRdxCost = - TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, true); + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); int SplittingRdxCost = - TTI->getArithmeticReductionCost(ReductionOpcode, VecTy, false); + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4616,7 +4726,7 @@ if (!IsPairwiseReduction) return createSimpleTargetReduction( - Builder, TTI, ReductionOpcode, VectorizedValue, + Builder, TTI, ReductionData.getOpcode(), VectorizedValue, TargetTransformInfo::ReductionFlags(), RedOps); Value *TmpVec = VectorizedValue; @@ -4631,8 +4741,9 @@ Value *RightShuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); - TmpVec = - Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, "bin.rdx"); + OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf, + RightShuf); + TmpVec = VectReductionData.createOp(Builder, "bin.rdx"); propagateIRFlags(TmpVec, RedOps); }