Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -910,8 +910,11 @@ IRBuilder<> Builder; /// A map of scalar integer values to the smallest bit width with which they - /// can legally be represented. - MapVector<Value *, uint64_t> MinBWs; + /// can legally be represented. The values map to (width, signed) pairs, + /// where "width" indicates the minimum bit width and "signed" is True if the + /// value must be sign-extended, rather than zero-extended, back to its + /// original width. + MapVector<Value *, std::pair<uint64_t, bool>> MinBWs; }; } // end namespace llvm @@ -1572,8 +1575,8 @@ // If we have computed a smaller type for the expression, update VecTy so // that the costs will be accurate. if (MinBWs.count(VL[0])) - VecTy = VectorType::get(IntegerType::get(F->getContext(), MinBWs[VL[0]]), - VL.size()); + VecTy = VectorType::get( + IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); if (E->NeedToGather) { if (allConstant(VL)) @@ -1929,10 +1932,12 @@ auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); auto *ScalarRoot = VectorizableTree[0].Scalars[0]; if (MinBWs.count(ScalarRoot)) { - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); + auto ExTy = + MinBWs[ScalarRoot].second ? 
Instruction::SExt : Instruction::ZExt; VecTy = VectorType::get(MinTy, BundleWidth); - ExtractCost += TTI->getExtractWithExtendCost( - Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane); + ExtractCost += TTI->getExtractWithExtendCost(ExTy, EU.Scalar->getType(), VecTy, EU.Lane); } else { ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); } @@ -2718,7 +2723,7 @@ if (auto *I = dyn_cast<Instruction>(VectorRoot)) Builder.SetInsertPoint(&*++BasicBlock::iterator(I)); auto BundleWidth = VectorizableTree[0].Scalars.size(); - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); auto *VecTy = VectorType::get(MinTy, BundleWidth); auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy); VectorizableTree[0].VectorizedValue = Trunc; @@ -2726,6 +2731,16 @@ DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); + // If necessary, sign-extend or zero-extend ScalarRoot to the larger type + // specified by ScalarType. + auto extend = [&](Value *ScalarRoot, Value *Ex, Type *ScalarType) { + if (!MinBWs.count(ScalarRoot)) + return Ex; + if (MinBWs[ScalarRoot].second) + return Builder.CreateSExt(Ex, ScalarType); + return Builder.CreateZExt(Ex, ScalarType); + }; + // Extract all of the elements with the external uses. 
for (const auto &ExternalUse : ExternalUses) { Value *Scalar = ExternalUse.Scalar; @@ -2760,8 +2775,7 @@ Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); } Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, Ex); } @@ -2769,16 +2783,14 @@ } else { Builder.SetInsertPoint(cast<Instruction>(User)); Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(cast<Instruction>(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, Ex); } @@ -3499,6 +3511,20 @@ Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth); } + // True if the roots must be sign-extended back to their original type, + // rather than zero-extended. We know that if the leading zeros are not + // demanded, we can safely zero-extend. So we initialize IsSigned to False. + bool IsSigned = false; + + // Wrapper around llvm::ComputeSignBit to determine if the sign bit of a + // value is known to be zero. + auto isSignBitKnownZero = [&](Value *V) -> bool { + bool KnownZero = false; + bool KnownOne = false; + ComputeSignBit(V, KnownZero, KnownOne, *DL); + return KnownZero; + }; + // If all the bits of the roots are demanded, we can try a little harder to // compute a narrower type. This can happen, for example, if the roots are // getelementptr indices. 
InstCombine promotes these indices to the pointer @@ -3510,11 +3536,30 @@ // compute the number of high-order bits we can truncate. if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType())) { MaxBitWidth = 8u; + + // Determine if the sign bit of all the roots is known to be zero. If not, + // IsSigned is set to True. + for (auto *Root : TreeRoot) { + IsSigned = !isSignBitKnownZero(Root); + if (IsSigned) + break; + } + + // Determine the maximum number of bits required to store the scalar + // values. for (auto *Scalar : ToDemote) { auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, 0, DT); auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType()); MaxBitWidth = std::max(NumTypeBits - NumSignBits, MaxBitWidth); } + + // If we can't prove that the sign bit is zero, we must add one to the + // maximum bit width to account for the unknown sign bit. This preserves + // the existing sign bit so we can safely sign-extend the root back to the + // original type. Otherwise, if we know the sign bit is zero, we will + // zero-extend the root instead. + if (IsSigned) + ++MaxBitWidth; } // Round MaxBitWidth up to the next power-of-two. @@ -3534,7 +3579,7 @@ // Finally, map the values we can demote to the maximum bit width we computed for (auto *Scalar : ToDemote) - MinBWs[Scalar] = MaxBitWidth; + MinBWs[Scalar] = std::make_pair(MaxBitWidth, IsSigned); } namespace { Index: test/Transforms/SLPVectorizer/X86/minimum-sizes.ll =================================================================== --- /dev/null +++ test/Transforms/SLPVectorizer/X86/minimum-sizes.ll @@ -0,0 +1,60 @@ +; RUN: opt -S -slp-vectorizer < %s | FileCheck %s + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +; This test ensures that we do not regress due to PR31243. When computing +; minimum sizes, if we can prove the sign bit is zero, we can zero-extend the +; roots back to their original sizes. 
Otherwise, we have to include one extra +; bit for signedness since we will sign-extend the root instead. +; +; CHECK-LABEL: @PR31243 +; CHECK: %[[E0:.+]] = extractelement <4 x i8> {{.*}}, i32 0 +; CHECK: %[[Z0:.+]] = zext i8 %[[E0]] to i32 +; CHECK: getelementptr inbounds i8, i8* %ptr, i32 %[[Z0]] +; CHECK: %[[E1:.+]] = extractelement <4 x i8> {{.*}}, i32 1 +; CHECK: %[[Z1:.+]] = zext i8 %[[E1]] to i32 +; CHECK: getelementptr inbounds i8, i8* %ptr, i32 %[[Z1]] +; CHECK: %[[E2:.+]] = extractelement <4 x i8> {{.*}}, i32 2 +; CHECK: %[[Z2:.+]] = zext i8 %[[E2]] to i32 +; CHECK: getelementptr inbounds i8, i8* %ptr, i32 %[[Z2]] +; CHECK: %[[E3:.+]] = extractelement <4 x i8> {{.*}}, i32 3 +; CHECK: %[[Z3:.+]] = zext i8 %[[E3]] to i32 +; CHECK: getelementptr inbounds i8, i8* %ptr, i32 %[[Z3]] +; +define void @PR31243(i8 %v0, i8 %v1, i8 %v2, i8 %v3, i8* %ptr) { +entry: + %tmp00 = xor i8 0, %v0 + %tmp01 = xor i8 0, %v1 + %tmp02 = xor i8 0, %v2 + %tmp03 = xor i8 0, %v3 + %tmp04 = add i8 %tmp00, 30 + %tmp05 = add i8 %tmp01, 29 + %tmp06 = add i8 %tmp02, 48 + %tmp07 = add i8 %tmp03, 67 + %tmp08 = shl i8 %tmp00, 1 + %tmp09 = shl i8 %tmp01, 1 + %tmp10 = shl i8 %tmp02, 1 + %tmp11 = shl i8 %tmp03, 1 + %tmp12 = and i8 %tmp08, 60 + %tmp13 = and i8 %tmp09, 58 + %tmp14 = and i8 %tmp10, 96 + %tmp15 = and i8 %tmp11, -122 + %tmp16 = sub i8 %tmp04, %tmp12 + %tmp17 = sub i8 %tmp05, %tmp13 + %tmp18 = sub i8 %tmp06, %tmp14 + %tmp19 = sub i8 %tmp07, %tmp15 + %tmp20 = xor i8 %tmp16, 30 + %tmp21 = xor i8 %tmp17, 29 + %tmp22 = xor i8 %tmp18, 48 + %tmp23 = xor i8 %tmp19, 67 + %tmp24 = zext i8 %tmp20 to i32 + %tmp25 = zext i8 %tmp21 to i32 + %tmp26 = zext i8 %tmp22 to i32 + %tmp27 = zext i8 %tmp23 to i32 + %tmp28 = getelementptr inbounds i8, i8* %ptr, i32 %tmp24 + %tmp29 = getelementptr inbounds i8, i8* %ptr, i32 %tmp25 + %tmp30 = getelementptr inbounds i8, i8* %ptr, i32 %tmp26 + %tmp31 = getelementptr inbounds i8, i8* %ptr, i32 %tmp27 + ret void +}