Index: lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -365,9 +365,10 @@
           TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li,
           DominatorTree *Dt, AssumptionCache *AC)
       : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func),
-        SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt),
+        SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC),
         Builder(Se->getContext()) {
     CodeMetrics::collectEphemeralValues(F, AC, EphValues);
+    MaxRequiredIntegerTy = nullptr;
   }
 
   /// \brief Vectorize the tree that starts with the elements in \p VL.
@@ -399,6 +400,7 @@
       BlockScheduling *BS = Iter.second.get();
       BS->clear();
     }
+    MaxRequiredIntegerTy = nullptr;
   }
 
   /// \returns true if the memory operations A and B are consecutive.
@@ -417,6 +419,10 @@
   /// calculate vectorization factors.
   unsigned getVectorElementSize(Value *V);
 
+  /// Compute the maximum width integer type required to represent the result
+  /// of a scalar expression, if such a type exists.
+  void computeMaxRequiredIntegerTy();
+
 private:
   struct TreeEntry;
@@ -922,8 +928,12 @@
   AliasAnalysis *AA;
   LoopInfo *LI;
   DominatorTree *DT;
+  AssumptionCache *AC;
   /// Instruction builder to construct the vectorized tree.
   IRBuilder<> Builder;
+
+  // The maximum width integer type required to represent a scalar expression.
+  IntegerType *MaxRequiredIntegerTy;
 };
 
 #ifndef NDEBUG
@@ -1479,6 +1489,15 @@
     ScalarTy = SI->getValueOperand()->getType();
   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
 
+  // If we have computed a smaller type for the expression, update VecTy so
+  // that the costs will be accurate.
+  if (MaxRequiredIntegerTy) {
+    auto *IT = dyn_cast<IntegerType>(ScalarTy);
+    assert(IT && "Computed smaller type for non-integer value?");
+    if (MaxRequiredIntegerTy->getBitWidth() < IT->getBitWidth())
+      VecTy = VectorType::get(MaxRequiredIntegerTy, VL.size());
+  }
+
   if (E->NeedToGather) {
     if (allConstant(VL))
       return 0;
@@ -1808,9 +1827,17 @@
     if (EphValues.count(I->User))
       continue;
 
-    VectorType *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth);
-    ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
-                                           I->Lane);
+    // If we plan to rewrite the tree in a smaller type, we will need to sign
+    // extend the extracted value back to the original type. Here, we account
+    // for the extract and the added cost of the sign extend if needed.
+    auto *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth);
+    if (MaxRequiredIntegerTy) {
+      VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
+      ExtractCost += TTI->getCastInstrCost(
+          Instruction::SExt, I->Scalar->getType(), MaxRequiredIntegerTy);
+    }
+    ExtractCost +=
+        TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, I->Lane);
  }
 
   Cost += getSpillCost();
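Note: the getCastInstrCost term added above models the sign extend that must
follow each external extract once the tree is rewritten in the narrower type.
At every external use the emitted IR ends up with roughly the following shape
(an illustrative sketch only, assuming an i64 tree narrowed to i32; the value
names are invented):

  %e = extractelement <8 x i32> %vec, i32 0   ; costed via getVectorInstrCost
  %s = sext i32 %e to i64                     ; costed via getCastInstrCost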
@@ -2565,7 +2592,19 @@
   }
 
   Builder.SetInsertPoint(&F->getEntryBlock().front());
-  vectorizeTree(&VectorizableTree[0]);
+  auto *VectorRoot = vectorizeTree(&VectorizableTree[0]);
+
+  // If the vectorized tree can be rewritten in a smaller type, we truncate the
+  // vectorized root. InstCombine will then rewrite the entire expression. We
+  // sign extend the extracted values below.
+  if (MaxRequiredIntegerTy) {
+    BasicBlock::iterator I(cast<Instruction>(VectorRoot));
+    Builder.SetInsertPoint(&*++I);
+    auto BundleWidth = VectorizableTree[0].Scalars.size();
+    auto *SmallerTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
+    auto *Trunc = Builder.CreateTrunc(VectorRoot, SmallerTy);
+    VectorizableTree[0].VectorizedValue = Trunc;
+  }
 
   DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n");
 
@@ -2598,6 +2637,8 @@
         if (PH->getIncomingValue(i) == Scalar) {
           Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator());
           Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+          if (MaxRequiredIntegerTy)
+            Ex = Builder.CreateSExt(Ex, Scalar->getType());
           CSEBlocks.insert(PH->getIncomingBlock(i));
           PH->setOperand(i, Ex);
         }
@@ -2605,12 +2646,16 @@
     } else {
       Builder.SetInsertPoint(cast<Instruction>(User));
       Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+      if (MaxRequiredIntegerTy)
+        Ex = Builder.CreateSExt(Ex, Scalar->getType());
       CSEBlocks.insert(cast<Instruction>(User)->getParent());
       User->replaceUsesOfWith(Scalar, Ex);
     }
   } else {
     Builder.SetInsertPoint(&F->getEntryBlock().front());
     Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+    if (MaxRequiredIntegerTy)
+      Ex = Builder.CreateSExt(Ex, Scalar->getType());
     CSEBlocks.insert(&F->getEntryBlock());
     User->replaceUsesOfWith(Scalar, Ex);
   }
@@ -3179,7 +3224,7 @@
   // If the current instruction is a load, update MaxWidth to reflect the
   // width of the loaded value.
   else if (isa<LoadInst>(I))
-    MaxWidth = std::max(MaxWidth, (unsigned)DL.getTypeSizeInBits(Ty));
+    MaxWidth = std::max<unsigned>(MaxWidth, DL.getTypeSizeInBits(Ty));
 
   // Otherwise, we need to visit the operands of the instruction. We only
   // handle the interesting cases from buildTree here. If an operand is an
@@ -3206,6 +3251,66 @@
   return MaxWidth;
 }
 
+void BoUpSLP::computeMaxRequiredIntegerTy() {
+
+  // If there are no external uses, there is nothing to do here.
+  if (ExternalUses.empty())
+    return;
+
+  // We only want to truncate the root of the expression tree to avoid having
+  // to deal with other external uses. The code below ensures that only the
+  // roots are used externally.
+  auto &TreeRoot = VectorizableTree[0].Scalars;
+  SmallPtrSet<Value *, 8> ScalarRoots(TreeRoot.begin(), TreeRoot.end());
+  for (auto &EU : ExternalUses)
+    if (!ScalarRoots.erase(EU.Scalar))
+      return;
+  if (!ScalarRoots.empty())
+    return;
+
+  // The maximum bit width required to represent all the instructions in the
+  // tree without loss of precision. It would be safe to truncate the
+  // expression to this width.
+  auto MaxBitWidth = 0u;
+
+  // Look at each entry in the tree.
+  for (auto &Entry : VectorizableTree) {
+
+    // Get a representative value for the vectorizable bundle. All values in
+    // Entry.Scalars should be isomorphic.
+    auto *Scalar = Entry.Scalars[0];
+
+    // We will rely on InstCombine to rewrite the expression in the narrower
+    // type. However, InstCombine only rewrites single-use values. If the
+    // scalar is used more than once, give up.
+    if (!Scalar->hasOneUse())
+      return;
+
+    // We only compute smaller integer types. If the scalar has a different
+    // type, give up.
+    auto *IT = dyn_cast<IntegerType>(Scalar->getType());
+    if (!IT)
+      return;
+
+    // Compute the maximum bit width required to store the scalar. We use
+    // ValueTracking to compute the number of high-order bits we can truncate.
+    // We then round up to the next power-of-two.
+    auto &DL = F->getParent()->getDataLayout();
+    auto NumSignBits = ComputeNumSignBits(Scalar, DL, 0, AC, 0, DT);
+    auto NumTypeBits = IT->getBitWidth();
+    auto BitWidth = std::max(NumTypeBits - NumSignBits, 8u);
+    if (!isPowerOf2_64(BitWidth))
+      BitWidth = NextPowerOf2(BitWidth);
+    MaxBitWidth = std::max(BitWidth, MaxBitWidth);
+  }
+
+  // If the maximum bit width we compute is less than the width of the roots'
+  // type, we can proceed with the narrowing. Otherwise, do nothing.
+  auto *RootIT = cast<IntegerType>(TreeRoot[0]->getType());
+  if (MaxBitWidth > 0 && MaxBitWidth < RootIT->getBitWidth())
+    MaxRequiredIntegerTy = IntegerType::get(F->getContext(), MaxBitWidth);
+}
+
 /// The SLPVectorizer Pass.
 struct SLPVectorizer : public FunctionPass {
   typedef SmallVector<StoreInst *, 8> StoreList;
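Note: a minimal worked example of the width computation above, using
hypothetical values rather than anything taken from the patch or the test:

  unsigned NumTypeBits = 64;  // the scalar is an i64
  unsigned NumSignBits = 49;  // ValueTracking: top 49 bits are sign-bit copies
  unsigned BitWidth = std::max(NumTypeBits - NumSignBits, 8u);  // 15
  if (!isPowerOf2_64(BitWidth))
    BitWidth = NextPowerOf2(BitWidth);  // 16, so this scalar fits in an i16

The 8-bit floor and the power-of-two rounding keep the computed type
byte-sized, which makes it much more likely to be legal on the target.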
@@ -3413,6 +3518,7 @@
       ArrayRef<Value *> Operands = Chain.slice(i, VF);
 
       R.buildTree(Operands);
+      R.computeMaxRequiredIntegerTy();
 
       int Cost = R.getTreeCost();
@@ -3603,6 +3709,7 @@
         Value *ReorderedOps[] = { Ops[1], Ops[0] };
         R.buildTree(ReorderedOps, None);
       }
+      R.computeMaxRequiredIntegerTy();
       int Cost = R.getTreeCost();
 
       if (Cost < -SLPCostThreshold) {
@@ -3869,6 +3976,7 @@
       for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
 
         V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps);
+        V.computeMaxRequiredIntegerTy();
 
         // Estimate cost.
         int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
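Note: end to end, the intent is that a vectorized root like the following
(an illustrative sketch with invented value names, assuming an i64 expression
whose results fit in i32):

  %w = sub nsw <8 x i64> %x, %y          ; vectorized root
  %t = trunc <8 x i64> %w to <8 x i32>   ; truncation emitted by this patch
  %e = extractelement <8 x i32> %t, i32 0
  %s = sext i32 %e to i64

gives InstCombine a single-use chain through which it can sink the truncation,
so the arithmetic itself is performed on <8 x i32>. The extract/sext pattern is
exactly what the updated test below checks for.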
Index: test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
===================================================================
--- test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
+++ test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
@@ -1,24 +1,20 @@
-; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s
+; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s
 
 target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64--linux-gnu"
 
-; These tests check that we vectorize the index calculations in the
-; gather-reduce pattern shown below. We check cases having i32 and i64
-; subtraction.
-;
 ; int gather_reduce_8x16(short *a, short *b, short *g, int n) {
 ;   int sum = 0;
 ;   for (int i = 0; i < n ; ++i) {
-;     sum += g[*a++ - *b++]; sum += g[*a++ - *b++];
-;     sum += g[*a++ - *b++]; sum += g[*a++ - *b++];
-;     sum += g[*a++ - *b++]; sum += g[*a++ - *b++];
-;     sum += g[*a++ - *b++]; sum += g[*a++ - *b++];
+;     sum += g[*a++ - b[0]]; sum += g[*a++ - b[4]];
+;     sum += g[*a++ - b[1]]; sum += g[*a++ - b[5]];
+;     sum += g[*a++ - b[2]]; sum += g[*a++ - b[6]];
+;     sum += g[*a++ - b[3]]; sum += g[*a++ - b[7]];
 ;   }
 ;   return sum;
 ; }
 
-; CHECK-LABEL: @gather_reduce_8x16_i32
+; CHECK-LABEL: @gather_reduce_8x16
 ;
 ; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
 ; CHECK: zext <8 x i16> [[L]] to <8 x i32>
@@ -26,129 +22,7 @@
 ; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
 ; CHECK: sext i32 [[X]] to i64
 ;
-define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
-entry:
-  %cmp.99 = icmp sgt i32 %n, 0
-  br i1 %cmp.99, label %for.body.preheader, label %for.cond.cleanup
-
-for.body.preheader:
-  br label %for.body
-
-for.cond.cleanup.loopexit:
-  br label %for.cond.cleanup
-
-for.cond.cleanup:
-  %sum.0.lcssa = phi i32 [ 0, %entry ], [ %add66, %for.cond.cleanup.loopexit ]
-  ret i32 %sum.0.lcssa
-
-for.body:
-  %i.0103 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
-  %sum.0102 = phi i32 [ %add66, %for.body ], [ 0, %for.body.preheader ]
-  %a.addr.0101 = phi i16* [ %incdec.ptr58, %for.body ], [ %a, %for.body.preheader ]
-  %b.addr.0100 = phi i16* [ %incdec.ptr60, %for.body ], [ %b, %for.body.preheader ]
-  %incdec.ptr = getelementptr inbounds i16, i16* %a.addr.0101, i64 1
-  %0 = load i16, i16* %a.addr.0101, align 2
-  %conv = zext i16 %0 to i32
-  %incdec.ptr1 = getelementptr inbounds i16, i16* %b.addr.0100, i64 1
-  %1 = load i16, i16* %b.addr.0100, align 2
-  %conv2 = zext i16 %1 to i32
-  %sub = sub nsw i32 %conv, %conv2
-  %arrayidx = getelementptr inbounds i16, i16* %g, i32 %sub
-  %2 = load i16, i16* %arrayidx, align 2
-  %conv3 = zext i16 %2 to i32
-  %add = add nsw i32 %conv3, %sum.0102
-  %incdec.ptr4 = getelementptr inbounds i16, i16* %a.addr.0101, i64 2
-  %3 = load i16, i16* %incdec.ptr, align 2
-  %conv5 = zext i16 %3 to i32
-  %incdec.ptr6 = getelementptr inbounds i16, i16* %b.addr.0100, i64 2
-  %4 = load i16, i16* %incdec.ptr1, align 2
-  %conv7 = zext i16 %4 to i32
-  %sub8 = sub nsw i32 %conv5, %conv7
-  %arrayidx10 = getelementptr inbounds i16, i16* %g, i32 %sub8
-  %5 = load i16, i16* %arrayidx10, align 2
-  %conv11 = zext i16 %5 to i32
-  %add12 = add nsw i32 %add, %conv11
-  %incdec.ptr13 = getelementptr inbounds i16, i16* %a.addr.0101, i64 3
-  %6 = load i16, i16* %incdec.ptr4, align 2
-  %conv14 = zext i16 %6 to i32
-  %incdec.ptr15 = getelementptr inbounds i16, i16* %b.addr.0100, i64 3
-  %7 = load i16, i16* %incdec.ptr6, align 2
-  %conv16 = zext i16 %7 to i32
-  %sub17 = sub nsw i32 %conv14, %conv16
-  %arrayidx19 = getelementptr inbounds i16, i16* %g, i32 %sub17
-  %8 = load i16, i16* %arrayidx19, align 2
-  %conv20 = zext i16 %8 to i32
-  %add21 = add nsw i32 %add12, %conv20
-  %incdec.ptr22 = getelementptr inbounds i16, i16* %a.addr.0101, i64 4
-  %9 = load i16, i16* %incdec.ptr13, align 2
-  %conv23 = zext i16 %9 to i32
-  %incdec.ptr24 = getelementptr inbounds i16, i16* %b.addr.0100, i64 4
-  %10 = load i16, i16* %incdec.ptr15, align 2
-  %conv25 = zext i16 %10 to i32
-  %sub26 = sub nsw i32 %conv23, %conv25
-  %arrayidx28 = getelementptr inbounds i16, i16* %g, i32 %sub26
-  %11 = load i16, i16* %arrayidx28, align 2
-  %conv29 = zext i16 %11 to i32
-  %add30 = add nsw i32 %add21, %conv29
-  %incdec.ptr31 = getelementptr inbounds i16, i16* %a.addr.0101, i64 5
-  %12 = load i16, i16* %incdec.ptr22, align 2
-  %conv32 = zext i16 %12 to i32
-  %incdec.ptr33 = getelementptr inbounds i16, i16* %b.addr.0100, i64 5
-  %13 = load i16, i16* %incdec.ptr24, align 2
-  %conv34 = zext i16 %13 to i32
-  %sub35 = sub nsw i32 %conv32, %conv34
-  %arrayidx37 = getelementptr inbounds i16, i16* %g, i32 %sub35
-  %14 = load i16, i16* %arrayidx37, align 2
-  %conv38 = zext i16 %14 to i32
-  %add39 = add nsw i32 %add30, %conv38
-  %incdec.ptr40 = getelementptr inbounds i16, i16* %a.addr.0101, i64 6
-  %15 = load i16, i16* %incdec.ptr31, align 2
-  %conv41 = zext i16 %15 to i32
-  %incdec.ptr42 = getelementptr inbounds i16, i16* %b.addr.0100, i64 6
-  %16 = load i16, i16* %incdec.ptr33, align 2
-  %conv43 = zext i16 %16 to i32
-  %sub44 = sub nsw i32 %conv41, %conv43
-  %arrayidx46 = getelementptr inbounds i16, i16* %g, i32 %sub44
-  %17 = load i16, i16* %arrayidx46, align 2
-  %conv47 = zext i16 %17 to i32
-  %add48 = add nsw i32 %add39, %conv47
-  %incdec.ptr49 = getelementptr inbounds i16, i16* %a.addr.0101, i64 7
-  %18 = load i16, i16* %incdec.ptr40, align 2
-  %conv50 = zext i16 %18 to i32
-  %incdec.ptr51 = getelementptr inbounds i16, i16* %b.addr.0100, i64 7
-  %19 = load i16, i16* %incdec.ptr42, align 2
-  %conv52 = zext i16 %19 to i32
-  %sub53 = sub nsw i32 %conv50, %conv52
-  %arrayidx55 = getelementptr inbounds i16, i16* %g, i32 %sub53
-  %20 = load i16, i16* %arrayidx55, align 2
-  %conv56 = zext i16 %20 to i32
-  %add57 = add nsw i32 %add48, %conv56
-  %incdec.ptr58 = getelementptr inbounds i16, i16* %a.addr.0101, i64 8
-  %21 = load i16, i16* %incdec.ptr49, align 2
-  %conv59 = zext i16 %21 to i32
-  %incdec.ptr60 = getelementptr inbounds i16, i16* %b.addr.0100, i64 8
-  %22 = load i16, i16* %incdec.ptr51, align 2
-  %conv61 = zext i16 %22 to i32
-  %sub62 = sub nsw i32 %conv59, %conv61
-  %arrayidx64 = getelementptr inbounds i16, i16* %g, i32 %sub62
-  %23 = load i16, i16* %arrayidx64, align 2
-  %conv65 = zext i16 %23 to i32
-  %add66 = add nsw i32 %add57, %conv65
-  %inc = add nuw nsw i32 %i.0103, 1
-  %exitcond = icmp eq i32 %inc, %n
-  br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
-}
-
-; CHECK-LABEL: @gather_reduce_8x16_i64
-;
-; CHECK-NOT: load <8 x i16>
-;
-; FIXME: We are currently unable to vectorize the case with i64 subtraction
-; because the zero extensions are too expensive. The solution here is to
-; convert the i64 subtractions to i32 subtractions during vectorization.
-; This would then match the case above.
-;
-define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
+define i32 @gather_reduce_8x16(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
 entry:
   %cmp.99 = icmp sgt i32 %n, 0
   br i1 %cmp.99, label %for.body.preheader, label %for.cond.cleanup
@@ -167,12 +41,11 @@
   %i.0103 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
   %sum.0102 = phi i32 [ %add66, %for.body ], [ 0, %for.body.preheader ]
   %a.addr.0101 = phi i16* [ %incdec.ptr58, %for.body ], [ %a, %for.body.preheader ]
-  %b.addr.0100 = phi i16* [ %incdec.ptr60, %for.body ], [ %b, %for.body.preheader ]
   %incdec.ptr = getelementptr inbounds i16, i16* %a.addr.0101, i64 1
   %0 = load i16, i16* %a.addr.0101, align 2
   %conv = zext i16 %0 to i64
-  %incdec.ptr1 = getelementptr inbounds i16, i16* %b.addr.0100, i64 1
-  %1 = load i16, i16* %b.addr.0100, align 2
+  %incdec.ptr1 = getelementptr inbounds i16, i16* %b, i64 1
+  %1 = load i16, i16* %b, align 2
   %conv2 = zext i16 %1 to i64
   %sub = sub nsw i64 %conv, %conv2
   %arrayidx = getelementptr inbounds i16, i16* %g, i64 %sub
@@ -182,7 +55,7 @@
   %incdec.ptr4 = getelementptr inbounds i16, i16* %a.addr.0101, i64 2
   %3 = load i16, i16* %incdec.ptr, align 2
   %conv5 = zext i16 %3 to i64
-  %incdec.ptr6 = getelementptr inbounds i16, i16* %b.addr.0100, i64 2
+  %incdec.ptr6 = getelementptr inbounds i16, i16* %b, i64 2
   %4 = load i16, i16* %incdec.ptr1, align 2
   %conv7 = zext i16 %4 to i64
   %sub8 = sub nsw i64 %conv5, %conv7
@@ -193,7 +66,7 @@
   %incdec.ptr13 = getelementptr inbounds i16, i16* %a.addr.0101, i64 3
   %6 = load i16, i16* %incdec.ptr4, align 2
   %conv14 = zext i16 %6 to i64
-  %incdec.ptr15 = getelementptr inbounds i16, i16* %b.addr.0100, i64 3
+  %incdec.ptr15 = getelementptr inbounds i16, i16* %b, i64 3
   %7 = load i16, i16* %incdec.ptr6, align 2
   %conv16 = zext i16 %7 to i64
   %sub17 = sub nsw i64 %conv14, %conv16
@@ -204,7 +77,7 @@
   %incdec.ptr22 = getelementptr inbounds i16, i16* %a.addr.0101, i64 4
   %9 = load i16, i16* %incdec.ptr13, align 2
   %conv23 = zext i16 %9 to i64
-  %incdec.ptr24 = getelementptr inbounds i16, i16* %b.addr.0100, i64 4
+  %incdec.ptr24 = getelementptr inbounds i16, i16* %b, i64 4
   %10 = load i16, i16* %incdec.ptr15, align 2
   %conv25 = zext i16 %10 to i64
   %sub26 = sub nsw i64 %conv23, %conv25
@@ -215,7 +88,7 @@
   %incdec.ptr31 = getelementptr inbounds i16, i16* %a.addr.0101, i64 5
   %12 = load i16, i16* %incdec.ptr22, align 2
   %conv32 = zext i16 %12 to i64
-  %incdec.ptr33 = getelementptr inbounds i16, i16* %b.addr.0100, i64 5
+  %incdec.ptr33 = getelementptr inbounds i16, i16* %b, i64 5
   %13 = load i16, i16* %incdec.ptr24, align 2
   %conv34 = zext i16 %13 to i64
   %sub35 = sub nsw i64 %conv32, %conv34
@@ -226,7 +99,7 @@
   %incdec.ptr40 = getelementptr inbounds i16, i16* %a.addr.0101, i64 6
   %15 = load i16, i16* %incdec.ptr31, align 2
   %conv41 = zext i16 %15 to i64
-  %incdec.ptr42 = getelementptr inbounds i16, i16* %b.addr.0100, i64 6
+  %incdec.ptr42 = getelementptr inbounds i16, i16* %b, i64 6
   %16 = load i16, i16* %incdec.ptr33, align 2
   %conv43 = zext i16 %16 to i64
   %sub44 = sub nsw i64 %conv41, %conv43
@@ -237,7 +110,7 @@
   %incdec.ptr49 = getelementptr inbounds i16, i16* %a.addr.0101, i64 7
   %18 = load i16, i16* %incdec.ptr40, align 2
   %conv50 = zext i16 %18 to i64
-  %incdec.ptr51 = getelementptr inbounds i16, i16* %b.addr.0100, i64 7
+  %incdec.ptr51 = getelementptr inbounds i16, i16* %b, i64 7
   %19 = load i16, i16* %incdec.ptr42, align 2
   %conv52 = zext i16 %19 to i64
   %sub53 = sub nsw i64 %conv50, %conv52
@@ -248,7 +121,6 @@
   %incdec.ptr58 = getelementptr inbounds i16, i16* %a.addr.0101, i64 8
   %21 = load i16, i16* %incdec.ptr49, align 2
   %conv59 = zext i16 %21 to i64
-  %incdec.ptr60 = getelementptr inbounds i16, i16* %b.addr.0100, i64 8
   %22 = load i16, i16* %incdec.ptr51, align 2
   %conv61 = zext i16 %22 to i64
   %sub62 = sub nsw i64 %conv59, %conv61
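Note: the updated test can be exercised standalone by mirroring its RUN line,
e.g. (paths assume an LLVM source checkout):

  opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine \
      test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll \
    | FileCheck test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll

The negative threshold lowers the profitability cutoff so the tree is
vectorized even if the cost model is borderline; the behavior under test is
the narrowing itself, not the cost decision.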