Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -451,6 +451,7 @@ static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl &Mask, Value *PermittedRHS, + Optional ExtOpc, InstCombiner &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); unsigned NumElts = V->getType()->getVectorNumElements(); @@ -472,6 +473,11 @@ Value *ScalarOp = IEI->getOperand(1); Value *IdxOp = IEI->getOperand(2); + if(ExtOpc) + if (CastInst *ZI = dyn_cast(ScalarOp)) + if (ZI->getOpcode() == *ExtOpc) + ScalarOp = ZI->getOperand(0); + if (ExtractElementInst *EI = dyn_cast(ScalarOp)) { if (isa(EI->getOperand(1)) && isa(IdxOp)) { unsigned ExtractedIdx = @@ -482,7 +488,7 @@ // otherwise we'd end up with a shuffle of three inputs. if (EI->getOperand(0) == PermittedRHS || PermittedRHS == nullptr) { Value *RHS = EI->getOperand(0); - ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, IC); + ShuffleOps LR = collectShuffleElements(VecOp, Mask, RHS, ExtOpc, IC); assert(LR.second == nullptr || LR.second == RHS); if (LR.first->getType() != RHS->getType()) { @@ -791,6 +797,15 @@ if (isa(ScalarOp) || isa(IdxOp)) replaceInstUsesWith(IE, VecOp); + Optional ExtOpc = None; + if (CastInst *ZI = dyn_cast(ScalarOp)) { + if (ZI->getOpcode() == Instruction::ZExt || + ZI->getOpcode() == Instruction::SExt) { + ScalarOp = ZI->getOperand(0); + ExtOpc = ZI->getOpcode(); + } + } + // If the inserted element was extracted from some other vector, and if the // indexes are constant, try to turn this into a shufflevector operation. if (ExtractElementInst *EI = dyn_cast(ScalarOp)) { @@ -817,7 +832,7 @@ // (and any insertelements it points to), into one big shuffle. if (!IE.hasOneUse() || !isa(IE.user_back())) { SmallVector Mask; - ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); + ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, ExtOpc, *this); // The proposed shuffle may be trivial, in which case we shouldn't // perform the combine. @@ -825,8 +840,16 @@ // We now have a shuffle of LHS, RHS, Mask. if (LR.second == nullptr) LR.second = UndefValue::get(LR.first->getType()); - return new ShuffleVectorInst(LR.first, LR.second, + + Instruction *SVI = new ShuffleVectorInst(LR.first, LR.second, ConstantVector::get(Mask)); + + if (ExtOpc) { + SVI->insertBefore(&IE); + SVI = CastInst::Create(*ExtOpc, SVI, IE.getType()); + } + + return SVI; } } } Index: test/Transforms/InstCombine/insert-extract-shuffle.ll =================================================================== --- test/Transforms/InstCombine/insert-extract-shuffle.ll +++ test/Transforms/InstCombine/insert-extract-shuffle.ll @@ -283,3 +283,51 @@ %ret = select i1 %e, <4 x i32> %b, <4 x i32> zeroinitializer ret <4 x i32> %ret } + +define <4 x i32> @test3(<8 x i16> %in) { +; CHECK-LABEL: @test3( +; CHECK-NEXT: [[VEC_3:%.*]] = shufflevector <8 x i16> %in, <8 x i16> undef, <4 x i32> +; CHECK-NEXT: [[ZEXT:%.*]] = zext <4 x i16> [[VEC_3]] to <4 x i32> +; CHECK-NEXT: ret <4 x i32> [[ZEXT]] +; + %elt0e = extractelement <8 x i16> %in, i32 3 + %elt1e = extractelement <8 x i16> %in, i32 1 + %elt2e = extractelement <8 x i16> %in, i32 0 + %elt3e = extractelement <8 x i16> %in, i32 3 + + %elt0 = zext i16 %elt0e to i32 + %elt1 = zext i16 %elt1e to i32 + %elt2 = zext i16 %elt2e to i32 + %elt3 = zext i16 %elt3e to i32 + + %vec.0 = insertelement <4 x i32> undef, i32 %elt0, i32 0 + %vec.1 = insertelement <4 x i32> %vec.0, i32 %elt1, i32 1 + %vec.2 = insertelement <4 x i32> %vec.1, i32 %elt2, i32 2 + %vec.3 = insertelement <4 x i32> %vec.2, i32 %elt3, i32 3 + + ret <4 x i32> %vec.3 +} + +define <4 x i32> @test4(<8 x i16> %in) { +; CHECK-LABEL: @test4( +; CHECK-NEXT: [[VEC_3:%.*]] = shufflevector <8 x i16> %in, <8 x i16> undef, <4 x i32> +; CHECK-NEXT: [[ZEXT:%.*]] = sext <4 x i16> [[VEC_3]] to <4 x i32> +; CHECK-NEXT: ret <4 x i32> [[ZEXT]] +; + %elt0e = extractelement <8 x i16> %in, i32 3 + %elt1e = extractelement <8 x i16> %in, i32 1 + %elt2e = extractelement <8 x i16> %in, i32 0 + %elt3e = extractelement <8 x i16> %in, i32 3 + + %elt0 = sext i16 %elt0e to i32 + %elt1 = sext i16 %elt1e to i32 + %elt2 = sext i16 %elt2e to i32 + %elt3 = sext i16 %elt3e to i32 + + %vec.0 = insertelement <4 x i32> undef, i32 %elt0, i32 0 + %vec.1 = insertelement <4 x i32> %vec.0, i32 %elt1, i32 1 + %vec.2 = insertelement <4 x i32> %vec.1, i32 %elt2, i32 2 + %vec.3 = insertelement <4 x i32> %vec.2, i32 %elt3, i32 3 + + ret <4 x i32> %vec.3 +}