Index: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -15,21 +15,22 @@ // "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks. // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Vectorize.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" @@ -44,7 +45,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Transforms/Vectorize.h" #include #include #include @@ -363,11 +364,12 @@ BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, - DominatorTree *Dt, AssumptionCache *AC) + DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB) : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func), - SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), + SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); + 
MaxRequiredIntegerTy = nullptr; } /// \brief Vectorize the tree that starts with the elements in \p VL. @@ -399,6 +401,7 @@ BlockScheduling *BS = Iter.second.get(); BS->clear(); } + MaxRequiredIntegerTy = nullptr; } /// \returns true if the memory operations A and B are consecutive. @@ -419,6 +422,10 @@ /// vectorization factors. unsigned getVectorElementSize(Value *V); + /// Compute the maximum width integer type required to represent the result + /// of a scalar expression, if such a type exists. + void computeMaxRequiredIntegerTy(); + private: struct TreeEntry; @@ -924,8 +931,13 @@ AliasAnalysis *AA; LoopInfo *LI; DominatorTree *DT; + AssumptionCache *AC; + DemandedBits *DB; /// Instruction builder to construct the vectorized tree. IRBuilder<> Builder; + + // The maximum width integer type required to represent a scalar expression. + IntegerType *MaxRequiredIntegerTy; }; #ifndef NDEBUG @@ -1481,6 +1493,15 @@ ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); + // If we have computed a smaller type for the expression, update VecTy so + // that the costs will be accurate. + if (MaxRequiredIntegerTy) { + auto *IT = dyn_cast(ScalarTy); + assert(IT && "Computed smaller type for non-integer value?"); + if (MaxRequiredIntegerTy->getBitWidth() < IT->getBitWidth()) + VecTy = VectorType::get(MaxRequiredIntegerTy, VL.size()); + } + if (E->NeedToGather) { if (allConstant(VL)) return 0; @@ -1809,9 +1830,17 @@ if (EphValues.count(EU.User)) continue; - VectorType *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); - ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, - EU.Lane); + // If we plan to rewrite the tree in a smaller type, we will need to sign + // extend the extracted value back to the original type. Here, we account + // for the extract and the added cost of the sign extend if needed. 
+ auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); + if (MaxRequiredIntegerTy) { + VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth); + ExtractCost += TTI->getCastInstrCost( + Instruction::SExt, EU.Scalar->getType(), MaxRequiredIntegerTy); + } + ExtractCost += + TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); } Cost += getSpillCost(); @@ -2566,7 +2595,19 @@ } Builder.SetInsertPoint(&F->getEntryBlock().front()); - vectorizeTree(&VectorizableTree[0]); + auto *VectorRoot = vectorizeTree(&VectorizableTree[0]); + + // If the vectorized tree can be rewritten in a smaller type, we truncate the + // vectorized root. InstCombine will then rewrite the entire expression. We + // sign extend the extracted values below. + if (MaxRequiredIntegerTy) { + BasicBlock::iterator I(cast(VectorRoot)); + Builder.SetInsertPoint(&*++I); + auto BundleWidth = VectorizableTree[0].Scalars.size(); + auto *SmallerTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth); + auto *Trunc = Builder.CreateTrunc(VectorRoot, SmallerTy); + VectorizableTree[0].VectorizedValue = Trunc; + } DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); @@ -2599,6 +2640,8 @@ if (PH->getIncomingValue(i) == Scalar) { Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, Ex); } @@ -2606,12 +2649,16 @@ } else { Builder.SetInsertPoint(cast(User)); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); CSEBlocks.insert(cast(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); + if (MaxRequiredIntegerTy) + Ex = Builder.CreateSExt(Ex, Scalar->getType()); 
CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, Ex); } @@ -3180,7 +3227,7 @@ // If the current instruction is a load, update MaxWidth to reflect the // width of the loaded value. else if (isa(I)) - MaxWidth = std::max(MaxWidth, (unsigned)DL.getTypeSizeInBits(Ty)); + MaxWidth = std::max(MaxWidth, DL.getTypeSizeInBits(Ty)); // Otherwise, we need to visit the operands of the instruction. We only // handle the interesting cases from buildTree here. If an operand is an @@ -3207,6 +3254,85 @@ return MaxWidth; } +void BoUpSLP::computeMaxRequiredIntegerTy() { + + // If there are no external uses, the expression tree must be rooted by a + // store. We can't demote in-memory values, so there is nothing to do here. + if (ExternalUses.empty()) + return; + + // If the expression is not rooted by a store, these roots should have + // external uses. We will rely on InstCombine to rewrite the expression in + // the narrower type. However, InstCombine only rewrites single-use values. + // This means that if a tree entry other than a root is used externally, it + // must have multiple uses and InstCombine will not rewrite it. The code + // below ensures that only the roots are used externally. + auto &TreeRoot = VectorizableTree[0].Scalars; + SmallPtrSet ScalarRoots(TreeRoot.begin(), TreeRoot.end()); + for (auto &EU : ExternalUses) + if (!ScalarRoots.erase(EU.Scalar)) + return; + if (!ScalarRoots.empty()) + return; + + // The maximum bit width required to represent all the instructions in the + // tree without loss of precision. It would be safe to truncate the + // expression to this width. + auto MaxBitWidth = 8u; + + // We first check if all the bits of the root are demanded. If they're not, + // we can truncate the root to this narrower type. 
+ auto *Root = dyn_cast(TreeRoot[0]); + if (!Root || !isa(Root->getType()) || !Root->hasOneUse()) + return; + auto Mask = DB->getDemandedBits(Root); + if (Mask.countLeadingZeros() > 0) + MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); + + // If all the bits of the root are demanded, we can try a little harder to + // compute a narrower type. This can happen, for example, if the roots are + // getelementptr indices. InstCombine promotes these indices to the pointer + // width. Thus, all their bits are technically demanded even though the + // address computation might be vectorized in a smaller type. We start by + // looking at each entry in the tree. + else + for (auto &Entry : VectorizableTree) { + + // Get a representative value for the vectorizable bundle. All values in + // Entry.Scalars should be isomorphic. + auto *Scalar = Entry.Scalars[0]; + + // If the scalar is used more than once, InstCombine will not rewrite it, + // so we should give up. + if (!Scalar->hasOneUse()) + return; + + // We only compute smaller integer types. If the scalar has a different + // type, give up. + auto *IT = dyn_cast(Scalar->getType()); + if (!IT) + return; + + // Compute the maximum bit width required to store the scalar. We use + // ValueTracking to compute the number of high-order bits we can + // truncate. We then round up to the next power-of-two. + auto &DL = F->getParent()->getDataLayout(); + auto NumSignBits = ComputeNumSignBits(Scalar, DL, 0, AC, 0, DT); + auto NumTypeBits = IT->getBitWidth(); + MaxBitWidth = std::max(NumTypeBits - NumSignBits, MaxBitWidth); + } + + // Round up to the next power-of-two. + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + + // If the maximum bit width we compute is less than the width of the roots' + // type, we can proceed with the narrowing. Otherwise, do nothing. 
+ auto *RootIT = cast(TreeRoot[0]->getType()); + if (MaxBitWidth > 0 && MaxBitWidth < RootIT->getBitWidth()) + MaxRequiredIntegerTy = IntegerType::get(F->getContext(), MaxBitWidth); +} + /// The SLPVectorizer Pass. struct SLPVectorizer : public FunctionPass { typedef SmallVector StoreList; @@ -3228,6 +3354,7 @@ LoopInfo *LI; DominatorTree *DT; AssumptionCache *AC; + DemandedBits *DB; bool runOnFunction(Function &F) override { if (skipOptnoneFunction(F)) @@ -3241,6 +3368,7 @@ LI = &getAnalysis().getLoopInfo(); DT = &getAnalysis().getDomTree(); AC = &getAnalysis().getAssumptionCache(F); + DB = &getAnalysis(); Stores.clear(); GEPs.clear(); @@ -3270,7 +3398,7 @@ // Use the bottom up slp vectorizer to construct chains that start with // store instructions. - BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC); + BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB); // A general note: the vectorizer must use BoUpSLP::eraseInstruction() to // delete instructions. @@ -3313,6 +3441,7 @@ AU.addRequired(); AU.addRequired(); AU.addRequired(); + AU.addRequired(); AU.addPreserved(); AU.addPreserved(); AU.addPreserved(); @@ -3417,6 +3546,7 @@ ArrayRef Operands = Chain.slice(i, VF); R.buildTree(Operands); + R.computeMaxRequiredIntegerTy(); int Cost = R.getTreeCost(); @@ -3616,6 +3746,7 @@ Value *ReorderedOps[] = { Ops[1], Ops[0] }; R.buildTree(ReorderedOps, None); } + R.computeMaxRequiredIntegerTy(); int Cost = R.getTreeCost(); if (Cost < -SLPCostThreshold) { @@ -3882,6 +4013,7 @@ for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) { V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps); + V.computeMaxRequiredIntegerTy(); // Estimate cost. 
int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]); Index: llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll =================================================================== --- llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll +++ llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll @@ -1,4 +1,5 @@ -; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s +; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE +; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128" target triple = "aarch64--linux-gnu" @@ -18,13 +19,13 @@ ; return sum; ; } -; CHECK-LABEL: @gather_reduce_8x16_i32 +; PROFITABLE-LABEL: @gather_reduce_8x16_i32 ; -; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> -; CHECK: zext <8 x i16> [[L]] to <8 x i32> -; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> -; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] -; CHECK: sext i32 [[X]] to i64 +; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32> +; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; PROFITABLE: sext i32 [[X]] to i64 ; define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: @@ -137,14 +138,18 @@ br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body } -; CHECK-LABEL: @gather_reduce_8x16_i64 +; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64 ; -; CHECK-NOT: load <8 x i16> +; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16> +; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32> +; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32> +; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]] +; UNPROFITABLE: sext i32 [[X]] to i64 ; -; FIXME: 
We are currently unable to vectorize the case with i64 subtraction -; because the zero extensions are too expensive. The solution here is to -; convert the i64 subtractions to i32 subtractions during vectorization. -; This would then match the case above. +; TODO: Although we can now vectorize this case while converting the i64 +; subtractions to i32, the cost model currently finds vectorization to be +; unprofitable. The cost model is penalizing the sign and zero +; extensions in the vectorized version, but they are actually free. ; define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) { entry: