diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -4001,10 +4001,30 @@
 }
 
 SDValue DAGTypeLegalizer::WidenVecRes_INSERT_VECTOR_ELT(SDNode *N) {
-  SDValue InOp = GetWidenedVector(N->getOperand(0));
-  return DAG.getNode(ISD::INSERT_VECTOR_ELT, SDLoc(N),
-                     InOp.getValueType(), InOp,
-                     N->getOperand(1), N->getOperand(2));
+  SDValue InVector = N->getOperand(0);
+  SDValue WideVector = GetWidenedVector(InVector);
+  SDLoc DL(N);
+
+  if (InVector.getValueType().isFixedLengthVector()) {
+    return DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVector.getValueType(),
+                       WideVector, N->getOperand(1), N->getOperand(2));
+  }
+
+  assert(WideVector.getValueType().isScalableVector() &&
+         "Expected widened vector to be a scalable type");
+  unsigned MinElems = InVector.getValueType().getVectorMinNumElements();
+  assert(isPowerOf2_64(MinElems) &&
+         "Scalable vector must have a minimum element count that is a power of two");
+  unsigned WidenedMinElems =
+      WideVector.getValueType().getVectorMinNumElements();
+
+  MVT PtrTy = TLI.getVectorIdxTy(DAG.getDataLayout());
+  SDValue IdxFactor = DAG.getConstant((WidenedMinElems / MinElems), DL, PtrTy);
+  SDValue Idx = N->getOperand(2);
+
+  SDValue OffsetIndex = DAG.getNode(ISD::MUL, DL, PtrTy, Idx, IdxFactor);
+  return DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, WideVector.getValueType(),
+                     WideVector, N->getOperand(1), OffsetIndex);
 }
 
 SDValue DAGTypeLegalizer::WidenVecRes_LOAD(SDNode *N) {
diff --git a/llvm/test/CodeGen/AArch64/sve-insert-element.ll b/llvm/test/CodeGen/AArch64/sve-insert-element.ll
--- a/llvm/test/CodeGen/AArch64/sve-insert-element.ll
+++ b/llvm/test/CodeGen/AArch64/sve-insert-element.ll
@@ -352,3 +352,63 @@
   %res = insertelement <vscale x 2 x double> undef, double %d, i64 %idx
   ret <vscale x 2 x double> %res
 }
+
+; Widened Vector Insert
+define <vscale x 1 x i16> @test_Unpacked_Vector_Insert_nxv1i16(<vscale x 1 x i16> %a, i16 %b) {
+; CHECK-LABEL: test_Unpacked_Vector_Insert_nxv1i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    index z1.h, #0, #1
+; CHECK-NEXT:    mov z2.h, w8
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    cmpeq p0.h, p0/z, z1.h, z2.h
+; CHECK-NEXT:    mov z0.h, p0/m, w0
+; CHECK-NEXT:    ret
+  %c = insertelement <vscale x 1 x i16> %a, i16 %b, i32 2
+  ret <vscale x 1 x i16> %c
+}
+
+define <vscale x 1 x double> @test_Unpacked_Vector_Insert_nxv1f64(<vscale x 1 x double> %a, double %b) {
+; CHECK-LABEL: test_Unpacked_Vector_Insert_nxv1f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #2
+; CHECK-NEXT:    index z2.d, #0, #1
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z2.d, z3.d
+; CHECK-NEXT:    mov z0.d, p0/m, d1
+; CHECK-NEXT:    ret
+  %c = insertelement <vscale x 1 x double> %a, double %b, i32 1
+  ret <vscale x 1 x double> %c
+}
+
+define <vscale x 1 x i8> @test_Unpacked_Vector_Insert_nxv1i8(<vscale x 1 x i8> %a, i8 %b, i32 %idx) {
+; CHECK-LABEL: test_Unpacked_Vector_Insert_nxv1i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    sbfiz x8, x1, #4, #32
+; CHECK-NEXT:    index z1.b, #0, #1
+; CHECK-NEXT:    mov z2.b, w8
+; CHECK-NEXT:    ptrue p0.b
+; CHECK-NEXT:    cmpeq p0.b, p0/z, z1.b, z2.b
+; CHECK-NEXT:    mov z0.b, p0/m, w0
+; CHECK-NEXT:    ret
+  %c = insertelement <vscale x 1 x i8> %a, i8 %b, i32 %idx
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 1 x float> @test_Unpacked_Vector_Insert_nxv1f32(<vscale x 1 x float> %a, float %b, i32 %idx) {
+; CHECK-LABEL: test_Unpacked_Vector_Insert_nxv1f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    sbfiz x8, x0, #1, #32
+; CHECK-NEXT:    index z2.d, #0, #1
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    cmpeq p0.d, p0/z, z2.d, z3.d
+; CHECK-NEXT:    mov z0.s, p0/m, s1
+; CHECK-NEXT:    ret
+  %c = insertelement <vscale x 1 x float> %a, float %b, i32 %idx
+  ret <vscale x 1 x float> %c
+}
+