Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -352,6 +352,71 @@

   return false;
 }
+/// If we have insertion into a vector that is wider than the vector that we
+/// are extracting from, try to widen the source vector to allow a single
+/// shufflevector to replace one or more insert/extract pairs.
+///
+/// \param InsElt The insertelement whose (wider) result type sets the width
+///               of the widened vector.
+/// \param ExtElt The extractelement paired with \p InsElt; its vector operand
+///               is the narrow vector that gets widened.
+/// \param IC     The combiner used to redirect uses of the replaced extracts.
+static void replaceExtractElements(InsertElementInst *InsElt,
+                                   ExtractElementInst *ExtElt,
+                                   InstCombiner &IC) {
+  VectorType *InsVecType = InsElt->getType();
+  VectorType *ExtVecType = ExtElt->getVectorOperandType();
+  unsigned NumInsElts = InsVecType->getVectorNumElements();
+  unsigned NumExtElts = ExtVecType->getVectorNumElements();
+
+  // The inserted-to vector must be wider than the extracted-from vector.
+  if (InsVecType->getElementType() != ExtVecType->getElementType() ||
+      NumExtElts >= NumInsElts)
+    return;
+
+  // Create a shuffle mask to widen the extracted-from vector using undefined
+  // values. The mask selects all of the values of the original vector followed
+  // by as many undefined values as needed to create a vector of the same length
+  // as the inserted-to vector.
+  SmallVector<Constant *, 16> ExtendMask;
+  IntegerType *IntType = Type::getInt32Ty(InsElt->getContext());
+  for (unsigned i = 0; i < NumExtElts; ++i)
+    ExtendMask.push_back(ConstantInt::get(IntType, i));
+  for (unsigned i = NumExtElts; i < NumInsElts; ++i)
+    ExtendMask.push_back(UndefValue::get(IntType));
+
+  Value *ExtVecOp = ExtElt->getVectorOperand();
+  auto *WideVec = new ShuffleVectorInst(ExtVecOp, UndefValue::get(ExtVecType),
+                                        ConstantVector::get(ExtendMask));
+
+  // Insert the wide vector where it dominates every extract rewritten below.
+  // Directly after the definition of the narrow vector is best, because then
+  // all extracts of that vector can use it. That position is unavailable when
+  // the narrow vector is not an instruction, is a PHI (a non-PHI cannot sit
+  // among the PHIs of a block), or is a terminator; in those cases use the
+  // first insertion point of ExtElt's block and rewrite only that block.
+  auto *ExtVecOpInst = dyn_cast<Instruction>(ExtVecOp);
+  bool InsertAfterDef = ExtVecOpInst && !isa<PHINode>(ExtVecOpInst) &&
+                        !ExtVecOpInst->isTerminator();
+  if (InsertAfterDef)
+    WideVec->insertAfter(ExtVecOpInst);
+  else
+    WideVec->insertBefore(&*ExtElt->getParent()->getFirstInsertionPt());
+
+  // Replace extracts from the original narrow vector with extracts from the
+  // new wide vector. Each new extract goes right after the one it replaces so
+  // that its index operand is already defined and it dominates the old uses.
+  for (User *U : ExtVecOp->users()) {
+    auto *OldExt = dyn_cast<ExtractElementInst>(U);
+    if (!OldExt)
+      continue;
+    if (!InsertAfterDef && OldExt->getParent() != WideVec->getParent())
+      continue;
+    auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1));
+    NewExt->insertAfter(OldExt);
+    IC.ReplaceInstUsesWith(*OldExt, NewExt);
+  }
+}

 /// We are building a shuffle to create V, which is a sequence of insertelement,
 /// extractelement pairs. If PermittedRHS is set, then we must either use it or
@@ -365,7 +430,8 @@

 static ShuffleOps collectShuffleElements(Value *V,
                                          SmallVectorImpl &Mask,
-                                         Value *PermittedRHS) {
+                                         Value *PermittedRHS,
+                                         InstCombiner &IC) {
   assert(V->getType()->isVectorTy() && "Invalid shuffle!");
   unsigned NumElts = cast(V->getType())->getNumElements();

@@ -396,10 +462,14 @@
         // otherwise we'd end up with a shuffle of three inputs.
         if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) {
           Value *RHS = EI->getOperand(0);
-          ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS);
+          ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC);
           assert(LR.second == nullptr || LR.second == RHS);

           if (LR.first->getType() != RHS->getType()) {
+            // Although we are giving up for now, see if we can create extracts
+            // that match the inserts for another round of combining.
+            replaceExtractElements(IEI, EI, IC);
+
             // We tried our best, but we can't find anything compatible with RHS
             // further up the chain. Return a trivial shuffle.
             for (unsigned i = 0; i < NumElts; ++i)
@@ -512,7 +582,7 @@
   // (and any insertelements it points to), into one big shuffle.
   if (!IE.hasOneUse() || !isa(IE.user_back())) {
     SmallVector Mask;
-    ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr);
+    ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this);

     // The proposed shuffle may be trivial, in which case we shouldn't
     // perform the combine.
Index: llvm/trunk/test/Transforms/InstCombine/insert-extract-shuffle.ll =================================================================== --- llvm/trunk/test/Transforms/InstCombine/insert-extract-shuffle.ll +++ llvm/trunk/test/Transforms/InstCombine/insert-extract-shuffle.ll @@ -26,8 +26,8 @@ define <2 x i64> @test_vcopyq_lane_p64(<2 x i64> %a, <1 x i64> %b) { ; CHECK-LABEL: @test_vcopyq_lane_p64 -; CHECK-NEXT: extractelement -; CHECK-NEXT: insertelement +; CHECK-NEXT: %[[WIDEVEC:.*]] = shufflevector <1 x i64> %b, <1 x i64> undef, <2 x i32> +; CHECK-NEXT: shufflevector <2 x i64> %a, <2 x i64> %[[WIDEVEC]], <2 x i32> ; CHECK-NEXT: ret <2 x i64> %res %elt = extractelement <1 x i64> %b, i32 0 %res = insertelement <2 x i64> %a, i64 %elt, i32 1 @@ -38,10 +38,8 @@ define <4 x float> @widen_extract2(<4 x float> %ins, <2 x float> %ext) { ; CHECK-LABEL: @widen_extract2( -; CHECK-NEXT: extractelement -; CHECK-NEXT: extractelement -; CHECK-NEXT: insertelement -; CHECK-NEXT: insertelement +; CHECK-NEXT: %[[WIDEVEC:.*]] = shufflevector <2 x float> %ext, <2 x float> undef, <4 x i32> +; CHECK-NEXT: shufflevector <4 x float> %ins, <4 x float> %[[WIDEVEC]], <4 x i32> ; CHECK-NEXT: ret <4 x float> %i2 %e1 = extractelement <2 x float> %ext, i32 0 %e2 = extractelement <2 x float> %ext, i32 1 @@ -52,12 +50,8 @@ define <4 x float> @widen_extract3(<4 x float> %ins, <3 x float> %ext) { ; CHECK-LABEL: @widen_extract3( -; CHECK-NEXT: extractelement -; CHECK-NEXT: extractelement -; CHECK-NEXT: extractelement -; CHECK-NEXT: insertelement -; CHECK-NEXT: insertelement -; CHECK-NEXT: insertelement +; CHECK-NEXT: %[[WIDEVEC:.*]] = shufflevector <3 x float> %ext, <3 x float> undef, <4 x i32> +; CHECK-NEXT: shufflevector <4 x float> %ins, <4 x float> %[[WIDEVEC]], <4 x i32> ; CHECK-NEXT: ret <4 x float> %i3 %e1 = extractelement <3 x float> %ext, i32 0 %e2 = extractelement <3 x float> %ext, i32 1 @@ -68,10 +62,10 @@ ret <4 x float> %i3 } -define <8 x float> @too_wide(<8 x float> %ins, <2 x float> 
%ext) { -; CHECK-LABEL: @too_wide( -; CHECK-NEXT: extractelement -; CHECK-NEXT: insertelement +define <8 x float> @widen_extract4(<8 x float> %ins, <2 x float> %ext) { +; CHECK-LABEL: @widen_extract4( +; CHECK-NEXT: %[[WIDEVEC:.*]] = shufflevector <2 x float> %ext, <2 x float> undef, <8 x i32> +; CHECK-NEXT: shufflevector <8 x float> %ins, <8 x float> %[[WIDEVEC]], <8 x i32> ; CHECK-NEXT: ret <8 x float> %i1 %e1 = extractelement <2 x float> %ext, i32 0 %i1 = insertelement <8 x float> %ins, float %e1, i32 2