diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1741,7 +1741,8 @@ return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); } -/// Try to replace a shuffle with an insertelement. +/// Try to replace a shuffle with an insertelement or try to replace a shuffle +/// operand with the operand of an insertelement. static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1); SmallVector Mask = Shuf.getShuffleMask(); @@ -1753,6 +1754,31 @@ if (NumElts != (int)(V0->getType()->getVectorNumElements())) return nullptr; + // This is a specialization of a fold in SimplifyDemandedVectorElts. We may + // not be able to handle it there if the insertelement has >1 use. + // If the shuffle has an insertelement operand but does not choose the + // inserted scalar element from that value, then we can replace that shuffle + // operand with the source vector of the insertelement. + Value *X; + uint64_t IdxC; + if (match(V0, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + // shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { + Shuf.setOperand(0, X); + return &Shuf; + } + } + if (match(V1, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + // Offset the index constant by the vector width because we are checking for + // accesses to the 2nd vector input of the shuffle. + IdxC += NumElts; + // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { + Shuf.setOperand(1, X); + return &Shuf; + } + } + // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { // We need an insertelement with a constant index. diff --git a/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll b/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll --- a/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll +++ b/llvm/test/Transforms/InstCombine/insert-extract-shuffle.ll @@ -670,7 +670,7 @@ ; CHECK-LABEL: @insert_undemanded_element_op0( ; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x float> [[X:%.*]], float 4.200000e+01, i32 3 ; CHECK-NEXT: call void @use(<4 x float> [[INS]]) -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x float> [[INS]], <4 x float> [[Y:%.*]], <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x float> [[X]], <4 x float> [[Y:%.*]], <4 x i32> ; CHECK-NEXT: ret <4 x float> [[S]] ; %ins = insertelement <4 x float> %x, float 42.0, i32 3 @@ -683,7 +683,7 @@ ; CHECK-LABEL: @insert_undemanded_element_op1( ; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x float> [[X:%.*]], float 4.200000e+01, i32 3 ; CHECK-NEXT: call void @use(<4 x float> [[INS]]) -; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x float> [[Y:%.*]], <4 x float> [[INS]], <4 x i32> +; CHECK-NEXT: [[S:%.*]] = shufflevector <4 x float> [[Y:%.*]], <4 x float> [[X]], <4 x i32> ; CHECK-NEXT: ret <4 x float> [[S]] ; %ins = insertelement <4 x float> %x, float 42.0, i32 3 @@ -692,6 +692,8 @@ ret <4 x float> %s } +; Negative test - shuffle chooses the inserted constant. + define <4 x float> @insert_demanded_element_op0(<4 x float> %x, <4 x float> %y) { ; CHECK-LABEL: @insert_demanded_element_op0( ; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x float> [[X:%.*]], float 4.200000e+01, i32 3 @@ -705,6 +707,8 @@ ret <4 x float> %s } +; Negative test - shuffle chooses the inserted constant. + define <4 x float> @insert_demanded_element_op1(<4 x float> %x, <4 x float> %y) { ; CHECK-LABEL: @insert_demanded_element_op1( ; CHECK-NEXT: [[INS:%.*]] = insertelement <4 x float> [[X:%.*]], float 4.300000e+01, i32 3