diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -952,6 +952,8 @@
   SDValue WidenVecRes_EXTEND_VECTOR_INREG(SDNode* N);
   SDValue WidenVecRes_EXTRACT_SUBVECTOR(SDNode* N);
   SDValue WidenVecRes_INSERT_SUBVECTOR(SDNode *N);
+  SDValue WidenVecRes_VECTOR_DEINTERLEAVE(SDNode *N);
+  SDValue WidenVecRes_VECTOR_INTERLEAVE(SDNode *N);
   SDValue WidenVecRes_INSERT_VECTOR_ELT(SDNode* N);
   SDValue WidenVecRes_LOAD(SDNode* N);
   SDValue WidenVecRes_VP_LOAD(VPLoadSDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3932,6 +3932,10 @@
     Res = WidenVecRes_INSERT_SUBVECTOR(N);
     break;
   case ISD::EXTRACT_SUBVECTOR: Res = WidenVecRes_EXTRACT_SUBVECTOR(N); break;
+  case ISD::VECTOR_DEINTERLEAVE:
+    // WidenVecRes_VECTOR_DEINTERLEAVE records both widened results itself.
+    Res = WidenVecRes_VECTOR_DEINTERLEAVE(N);
+    return;
   case ISD::INSERT_VECTOR_ELT: Res = WidenVecRes_INSERT_VECTOR_ELT(N); break;
   case ISD::LOAD: Res = WidenVecRes_LOAD(N); break;
   case ISD::STEP_VECTOR:
@@ -5017,6 +5021,35 @@
       }
     }
   }
+  if (WidenVT.isScalableVector()) {
+    // Break down the inputs such that we can concatenate individual parts
+    // into one wider vector, e.g.
+    //
+    //    nxv6i32 = concat(nxv3i32 t0, nxv3i32 t1)
+    //      =>
+    //    nxv8i32 = concat(nxv1i32 (extract.. t0),
+    //                     nxv1i32 (extract.. t0),
+    //                     nxv1i32 (extract.. t0),
+    //                     nxv1i32 (extract.. t1),
+    //                     nxv1i32 (extract.. t1),
+    //                     nxv1i32 (extract.. t1),
+    //                     undef, undef)
+    unsigned GCD = std::gcd(InVT.getVectorMinNumElements(),
+                            WidenVT.getVectorMinNumElements());
+    EVT PartVT =
+        EVT::getVectorVT(*DAG.getContext(), InVT.getVectorElementType(),
+                         ElementCount::getScalable(GCD));
+    SmallVector<SDValue> Parts(WidenVT.getVectorMinNumElements() / GCD,
+                               DAG.getUNDEF(PartVT));
+    unsigned NumSubParts = InVT.getVectorMinNumElements() / GCD;
+    for (unsigned I = 0; I < NumOperands; ++I)
+      for (unsigned SI = 0; SI < NumSubParts; ++SI)
+        Parts[I * NumSubParts + SI] =
+            DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, PartVT,
+                        GetWidenedVector(N->getOperand(I)),
+                        DAG.getVectorIdxConstant(SI * GCD, dl));
+    return DAG.getNode(ISD::CONCAT_VECTORS, dl, WidenVT, Parts);
+  }
 
   assert(!WidenVT.isScalableVector() &&
          "Cannot use build vectors to widen CONCAT_VECTOR result");
@@ -5719,6 +5752,38 @@
                        Mask);
 }
 
+/// Widen both results of a VECTOR_DEINTERLEAVE node. The two operands are
+/// concatenated, the concatenation is widened when its type requires it, and
+/// the halves of the resulting vector are deinterleaved again. Both widened
+/// results are recorded here, so an empty SDValue is returned.
+SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
+  SDLoc dl(N);
+  EVT InVT = N->getValueType(0);
+  EVT InConcVT =
+      EVT::getVectorVT(*DAG.getContext(), InVT.getVectorElementType(),
+                       InVT.getVectorElementCount() * 2);
+  SDValue WidenVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, InConcVT,
+                                 N->getOperand(0), N->getOperand(1));
+  if (getTypeAction(WidenVec.getValueType()) == TargetLowering::TypeWidenVector)
+    WidenVec = GetWidenedVector(WidenVec);
+
+  EVT WidenVT = WidenVec.getValueType();
+  EVT SplitVT = EVT::getVectorVT(
+      *DAG.getContext(), WidenVT.getVectorElementType(),
+      WidenVT.getVectorMinNumElements() / 2, InVT.isScalableVector());
+  SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitVT, WidenVec,
+                           DAG.getVectorIdxConstant(0, dl));
+  SDValue Hi = DAG.getNode(
+      ISD::EXTRACT_SUBVECTOR, dl, SplitVT, WidenVec,
+      DAG.getVectorIdxConstant(SplitVT.getVectorMinNumElements(), dl));
+
+  SDValue Res = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, dl,
+                            DAG.getVTList(SplitVT, SplitVT), Lo, Hi);
+  SetWidenedVector(SDValue(N, 0), Res.getValue(0));
+  SetWidenedVector(SDValue(N, 1), Res.getValue(1));
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_SETCC(SDNode *N) {
   assert(N->getValueType(0).isVector() &&
          N->getOperand(0).getValueType().isVector() &&
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -11589,6 +11589,7 @@
 void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I) {
   auto DL = getCurSDLoc();
   EVT InVT = getValue(I.getOperand(0)).getValueType();
+  EVT WideVT = InVT;
   SDValue InVec0 = getValue(I.getOperand(0));
   SDValue InVec1 = getValue(I.getOperand(1));
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -11604,10 +11605,28 @@
     return;
   }
 
+  // Widen the operands first when their type has to be widened; otherwise
+  // the interleaved result is concatenated in the wrong order.
+  LLVMContext &Ctx = *DAG.getContext();
+  if (TLI.getTypeAction(Ctx, InVT) == TargetLowering::TypeWidenVector) {
+    WideVT = TLI.getTypeToTransformTo(Ctx, InVT);
+    InVec0 = widenVectorToPartType(DAG, InVec0, DL, WideVT);
+    InVec1 = widenVectorToPartType(DAG, InVec1, DL, WideVT);
+    OutVT = InVec0.getValueType().getDoubleNumVectorElementsVT(Ctx);
+  }
+
   SDValue Res = DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
-                            DAG.getVTList(InVT, InVT), InVec0, InVec1);
+                            DAG.getVTList(WideVT, WideVT), InVec0, InVec1);
   Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Res.getValue(0),
                     Res.getValue(1));
+
+  // Shrink the result back to the original, unwidened type.
+  if (InVT.getVectorElementCount() != WideVT.getVectorElementCount()) {
+    OutVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
+    Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, Res,
+                      DAG.getVectorIdxConstant(0, DL));
+  }
+
   setValue(&I, Res);
   return;
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll b/llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll
--- a/llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-deinterleave.ll
@@ -244,6 +244,27 @@
   ret {, } %retval
 }
 
+; Widen illegal type size
+
+define {<vscale x 6 x i64>, <vscale x 6 x i64>} @vector_deinterleave_nxv6i64_nxv12i64(<vscale x 12 x i64> %vec) {
+; CHECK-LABEL: vector_deinterleave_nxv6i64_nxv12i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uzp1 z6.d, z4.d, z5.d
+; CHECK-NEXT:    uzp1 z7.d, z2.d, z3.d
+; CHECK-NEXT:    uzp1 z24.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z25.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z26.d, z2.d, z3.d
+; CHECK-NEXT:    uzp2 z5.d, z4.d, z5.d
+; CHECK-NEXT:    mov z0.d, z24.d
+; CHECK-NEXT:    mov z1.d, z7.d
+; CHECK-NEXT:    mov z2.d, z6.d
+; CHECK-NEXT:    mov z3.d, z25.d
+; CHECK-NEXT:    mov z4.d, z26.d
+; CHECK-NEXT:    ret
+%retval = call {<vscale x 6 x i64>, <vscale x 6 x i64>} @llvm.experimental.vector.deinterleave2.nxv12i64(<vscale x 12 x i64> %vec)
+ret {<vscale x 6 x i64>, <vscale x 6 x i64>} %retval
+}
+
 
 ; Floating declarations
 declare {<vscale x 2 x half>,<vscale x 2 x half>} @llvm.experimental.vector.deinterleave2.nxv4f16(<vscale x 4 x half>)
@@ -272,3 +293,5 @@
 declare {<vscale x 8 x i8>, <vscale x 8 x i8>} @llvm.experimental.vector.deinterleave2.nxv16i8(<vscale x 16 x i8>)
 declare {<vscale x 4 x i16>, <vscale x 4 x i16>} @llvm.experimental.vector.deinterleave2.nxv8i16(<vscale x 8 x i16>)
 declare {<vscale x 2 x i32>, <vscale x 2 x i32>} @llvm.experimental.vector.deinterleave2.nxv4i32(<vscale x 4 x i32>)
+
+declare {<vscale x 6 x i64>, <vscale x 6 x i64>} @llvm.experimental.vector.deinterleave2.nxv12i64(<vscale x 12 x i64>)
diff --git a/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll b/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
--- a/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
+++ b/llvm/test/CodeGen/AArch64/sve-vector-interleave.ll
@@ -224,6 +224,20 @@
   ret %retval
 }
 
+; Widen illegal type size
+
+define <vscale x 6 x i32> @interleave2_nxv6i32(<vscale x 3 x i32> %vec0, <vscale x 3 x i32> %vec1) nounwind {
+; CHECK-LABEL: interleave2_nxv6i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    zip1 z2.s, z0.s, z1.s
+; CHECK-NEXT:    zip2 z1.s, z0.s, z1.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %retval = call <vscale x 6 x i32> @llvm.experimental.vector.interleave2.nxv6i32(<vscale x 3 x i32> %vec0, <vscale x 3 x i32> %vec1)
+  ret <vscale x 6 x i32> %retval
+}
+
+
 ; Float declarations
 declare <vscale x 4 x half> @llvm.experimental.vector.interleave2.nxv4f16(<vscale x 2 x half>, <vscale x 2 x half>)
 declare <vscale x 8 x half> @llvm.experimental.vector.interleave2.nxv8f16(<vscale x 4 x half>, <vscale x 4 x half>)
@@ -251,3 +265,5 @@
 declare <vscale x 16 x i8> @llvm.experimental.vector.interleave2.nxv16i8(<vscale x 8 x i8>, <vscale x 8 x i8>)
 declare <vscale x 8 x i16> @llvm.experimental.vector.interleave2.nxv8i16(<vscale x 4 x i16>, <vscale x 4 x i16>)
 declare <vscale x 4 x i32> @llvm.experimental.vector.interleave2.nxv4i32(<vscale x 2 x i32>, <vscale x 2 x i32>)
+
+declare <vscale x 6 x i32> @llvm.experimental.vector.interleave2.nxv6i32(<vscale x 3 x i32>, <vscale x 3 x i32>)