Index: lib/Transforms/Scalar/SROA.cpp
===================================================================
--- lib/Transforms/Scalar/SROA.cpp
+++ lib/Transforms/Scalar/SROA.cpp
@@ -2277,8 +2277,14 @@
   return V;
 }
 
-static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,
-                            unsigned EndIndex, const Twine &Name) {
+// \brief Extract a continuous range of elements from a vector.
+//
+// \param V Vector value to extract from.
+// \param TargetTy Type to which the return value will be converted. Used to
+//        optimize the vector extraction when possible.
+static Value *extractVector(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
+                            unsigned BeginIndex, unsigned EndIndex,
+                            Type *TargetTy, const Twine &Name) {
   VectorType *VecTy = cast<VectorType>(V->getType());
   unsigned NumElements = EndIndex - BeginIndex;
   assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
@@ -2293,7 +2299,65 @@
     return V;
   }
 
-  SmallVector<Constant *, 8> Mask;
+  SmallVector<Constant *, 8> Mask;
+
+  // Try to cast the vector to another vector type of the same bitwidth, and
+  // extract an element. This will work if the vector types are compatible, and
+  // the begin index is aligned to a value in the casted vector type. If the
+  // begin index isn't aligned then we can shuffle the original vector (keeping
+  // the same vector type) before extracting.
+  //
+  // This code will bail out if the target type is fundamentally incompatible
+  // with vectors of the source type.
+  //
+  // Example of <16 x i8>, target type i32:
+  // Index range [4,8):          v-----------v Will work.
+  //                 +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
+  //    <16 x i8>:   |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
+  //    <4 x i32>:   |           |           |           |           |
+  //                 +-----------+-----------+-----------+-----------+
+  // Index range [6,10):               ^-----------^ Needs an extra shuffle.
+  // Target type i40:            ^--------------^ Won't work, bail.
+  if (unsigned TargetElemBitWidth = TargetTy->getPrimitiveSizeInBits()) {
+    unsigned VecBitWidth = VecTy->getBitWidth();
+    unsigned SrcElemBitWidth =
+        VecTy->getElementType()->getPrimitiveSizeInBits();
+    assert(SrcElemBitWidth && "vector elements must have a bitwidth");
+    unsigned SrcNumElems = VecTy->getNumElements();
+    unsigned TargetNumElems = VecBitWidth / TargetElemBitWidth;
+    bool VecBitWidthsEqual = VecBitWidth == TargetNumElems * TargetElemBitWidth;
+    bool BeginIsAligned =
+        0 == ((SrcElemBitWidth * BeginIndex) % TargetElemBitWidth);
+    if (VecBitWidthsEqual && VectorType::isValidElementType(TargetTy)) {
+      VectorType *CastVecTy = VectorType::get(TargetTy, TargetNumElems);
+      if (canConvertValue(DL, VecTy, CastVecTy)) {
+        if (!BeginIsAligned) {
+          // Shuffle the input so [0,NumElements) contains the output, and
+          // [NumElems,SrcNumElems) is undef.
+          Mask.reserve(SrcNumElems);
+          unsigned i = BeginIndex;
+          while (i != EndIndex)
+            Mask.push_back(IRB.getInt32(i++));
+          while (i++ != SrcNumElems)
+            Mask.push_back(IRB.getInt32(SrcNumElems)); // undef
+          V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()),
+                                      ConstantVector::get(Mask),
+                                      Name + ".extract");
+          DEBUG(dbgs() << "     shuffle: " << *V << "\n");
+          BeginIndex = 0;
+        }
+        unsigned SrcElemsPerTargetElem = TargetElemBitWidth / SrcElemBitWidth;
+        assert(SrcElemsPerTargetElem);
+        BeginIndex /= SrcElemsPerTargetElem;
+        V = IRB.CreateExtractElement(convertValue(DL, IRB, V, CastVecTy),
+                                     IRB.getInt32(BeginIndex),
+                                     Name + ".extract");
+        DEBUG(dbgs() << "     extract: " << *V << "\n");
+        return V;
+      }
+    }
+  }
+
   Mask.reserve(NumElements);
   for (unsigned i = BeginIndex; i != EndIndex; ++i)
     Mask.push_back(IRB.getInt32(i));
@@ -2549,13 +2613,18 @@
     Pass.DeadInsts.insert(I);
   }
 
-  Value *rewriteVectorizedLoadInst() {
+  // \brief Rewrite a vector load instruction to a load followed by the
+  //        extraction of a subset of the vector's elements.
+  //
+  // \param TargetTy Type to which the return value will be converted. Used to
+  //        optimize the vector extraction when possible.
+  Value *rewriteVectorizedLoadInst(Type *TargetTy) {
     unsigned BeginIndex = getIndex(NewBeginOffset);
     unsigned EndIndex = getIndex(NewEndOffset);
     assert(EndIndex > BeginIndex && "Empty vector!");
 
     Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load");
-    return extractVector(IRB, V, BeginIndex, EndIndex, "vec");
+    return extractVector(DL, IRB, V, BeginIndex, EndIndex, TargetTy, "vec");
   }
 
   Value *rewriteIntegerLoad(LoadInst &LI) {
@@ -2581,7 +2650,7 @@
     bool IsPtrAdjusted = false;
     Value *V;
     if (VecTy) {
-      V = rewriteVectorizedLoadInst();
+      V = rewriteVectorizedLoadInst(TargetTy);
     } else if (IntTy && LI.getType()->isIntegerTy()) {
       V = rewriteIntegerLoad(LI);
     } else if (NewBeginOffset == NewAllocaBeginOffset &&
@@ -3003,7 +3072,10 @@
     Value *Src;
     if (VecTy && !IsWholeAlloca && !IsDest) {
       Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load");
-      Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec");
+      // FIXME: in some cases we can figure out a better target type which would
+      // allow generating an extract directly.
+      Type *TargetTy = OtherPtrTy->getPointerElementType();
+      Src = extractVector(DL, IRB, Src, BeginIndex, EndIndex, TargetTy, "vec");
     } else if (IntTy && !IsWholeAlloca && !IsDest) {
       Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load");
       Src = convertValue(DL, IRB, Src, IntTy);
Index: test/Transforms/SROA/vector-promotion.ll
===================================================================
--- test/Transforms/SROA/vector-promotion.ll
+++ test/Transforms/SROA/vector-promotion.ll
@@ -623,3 +623,28 @@
 ; CHECK-NEXT: ret <4 x float> %[[ret]]
   ret <4 x float> %vec
 }
+
+%U4xi32 = type { <4 x i32> }
+
+define i32 @type_pun(<16 x i8> %in) {
+; Ensure that type punning using a union of vector and same-sized array
+; generates an extract.
+;
+; CHECK-LABEL: @type_pun(
+; CHECK-NOT: alloca
+; CHECK-NEXT: %[[BC1:.*]] = bitcast <16 x i8> %in to <4 x i32>
+; CHECK-NEXT: %[[EXT1:.*]] = extractelement <4 x i32> %[[BC1]], i32 0
+; CHECK-NEXT: %[[BC2:.*]] = bitcast <16 x i8> %in to <4 x i32>
+; CHECK-NEXT: %[[EXT2:.*]] = extractelement <4 x i32> %[[BC2]], i32 2
+; CHECK-NEXT: %[[SUM:.*]] = add i32 %[[EXT1]], %[[EXT2]]
+; CHECK-NEXT: ret i32 %[[SUM]]
+  %stack = alloca %U4xi32, align 16
+  %vec = bitcast %U4xi32* %stack to <16 x i8>*
+  store <16 x i8> %in, <16 x i8>* %vec, align 16
+  %idx1 = getelementptr inbounds %U4xi32* %stack, i32 0, i32 0, i32 0
+  %elem1 = load i32* %idx1, align 4
+  %idx2 = getelementptr inbounds %U4xi32* %stack, i32 0, i32 0, i32 2
+  %elem2 = load i32* %idx2, align 4
+  %sum = add i32 %elem1, %elem2
+  ret i32 %sum
+}