// Patch summary. The diff text below is unchanged; these leading lines sit
// before the first "diff --git" header.
//
// LegalizeTypes.h
//   - Renames the declarations SplitVecRes_SCALAR_TO_VECTOR and
//     WidenVecRes_SCALAR_TO_VECTOR to SplitVecRes_ScalarOp and
//     WidenVecRes_ScalarOp. Parameter lists are unchanged.
//
// LegalizeVectorTypes.cpp
//   - The split-result and widen-result opcode switches now send
//     ISD::SPLAT_VECTOR, in addition to ISD::SCALAR_TO_VECTOR, to the renamed
//     *_ScalarOp handlers.
//   - SplitVecRes_ScalarOp builds Lo with the node's own opcode (previously
//     hard-coded to ISD::SCALAR_TO_VECTOR). Hi stays UNDEF for
//     SCALAR_TO_VECTOR; otherwise the opcode is asserted to be SPLAT_VECTOR
//     and Hi is set to the same value as Lo.
//   - WidenVecRes_ScalarOp re-emits the node's own opcode at the widened type
//     (previously hard-coded to ISD::SCALAR_TO_VECTOR).
//   - SplitVecOp_EXTRACT_SUBVECTOR now calls report_fatal_error when the
//     result type and the source operand type disagree on isScalableVector(),
//     and its element-count comparisons use getVectorMinNumElements() instead
//     of getVectorNumElements().
//
// sve-vector-splat.ll
//   - Adds the test function sve_splat_12xi32; its CHECK lines expect
//     "mov z0.s, w0" followed by "mov z1.d, z0.d" and "mov z2.d, z0.d".
//
// NOTE(review): "Hi = Lo" gives Hi the type LoVT; this presumably relies on
// GetSplitDestVTs returning identical halves for SPLAT_VECTOR -- confirm.
// NOTE(review): the fatal-error text says "Extracting fixed from scalable",
// but the condition is a plain inequality, so it also fires for the opposite
// direction -- confirm that is intended.
// NOTE(review): this copy of the patch appears to have lost its line breaks
// and the text inside angle brackets (e.g. "cast(N)", "cast(Idx)",
// "define @sve_splat_12xi32", "ret %splat"); regenerate it from the original
// commit before attempting to apply it.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -806,7 +806,7 @@ void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi); void SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, SDValue &Lo, SDValue &Hi); void SplitVecRes_MGATHER(MaskedGatherSDNode *MGT, SDValue &Lo, SDValue &Hi); - void SplitVecRes_SCALAR_TO_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi); + void SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi); void SplitVecRes_VECTOR_SHUFFLE(ShuffleVectorSDNode *N, SDValue &Lo, SDValue &Hi); @@ -862,7 +862,7 @@ SDValue WidenVecRes_LOAD(SDNode* N); SDValue WidenVecRes_MLOAD(MaskedLoadSDNode* N); SDValue WidenVecRes_MGATHER(MaskedGatherSDNode* N); - SDValue WidenVecRes_SCALAR_TO_VECTOR(SDNode* N); + SDValue WidenVecRes_ScalarOp(SDNode* N); SDValue WidenVecRes_SELECT(SDNode* N); SDValue WidenVSELECTAndMask(SDNode *N); SDValue WidenVecRes_SELECT_CC(SDNode* N); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -836,7 +836,10 @@ case ISD::FPOWI: SplitVecRes_FPOWI(N, Lo, Hi); break; case ISD::FCOPYSIGN: SplitVecRes_FCOPYSIGN(N, Lo, Hi); break; case ISD::INSERT_VECTOR_ELT: SplitVecRes_INSERT_VECTOR_ELT(N, Lo, Hi); break; - case ISD::SCALAR_TO_VECTOR: SplitVecRes_SCALAR_TO_VECTOR(N, Lo, Hi); break; + case ISD::SPLAT_VECTOR: + case ISD::SCALAR_TO_VECTOR: + SplitVecRes_ScalarOp(N, Lo, Hi); + break; case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break; case ISD::LOAD: SplitVecRes_LOAD(cast(N), Lo, Hi); @@ -1517,13 +1520,18 @@ Hi = DAG.getNode(ISD::TRUNCATE, dl, HiVT, Hi); } -void DAGTypeLegalizer::SplitVecRes_SCALAR_TO_VECTOR(SDNode *N, 
SDValue &Lo, - SDValue &Hi) { +void DAGTypeLegalizer::SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo, + SDValue &Hi) { EVT LoVT, HiVT; SDLoc dl(N); std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0)); - Lo = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, LoVT, N->getOperand(0)); - Hi = DAG.getUNDEF(HiVT); + Lo = DAG.getNode(N->getOpcode(), dl, LoVT, N->getOperand(0)); + if (N->getOpcode() == ISD::SCALAR_TO_VECTOR) { + Hi = DAG.getUNDEF(HiVT); + } else { + assert(N->getOpcode() == ISD::SPLAT_VECTOR && "Unexpected opcode"); + Hi = Lo; + } } void DAGTypeLegalizer::SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, @@ -2194,16 +2202,21 @@ SDValue DAGTypeLegalizer::SplitVecOp_EXTRACT_SUBVECTOR(SDNode *N) { // We know that the extracted result type is legal. EVT SubVT = N->getValueType(0); + + if (SubVT.isScalableVector() != + N->getOperand(0).getValueType().isScalableVector()) + report_fatal_error("Extracting fixed from scalable not implemented"); + SDValue Idx = N->getOperand(1); SDLoc dl(N); SDValue Lo, Hi; GetSplitVector(N->getOperand(0), Lo, Hi); - uint64_t LoElts = Lo.getValueType().getVectorNumElements(); + uint64_t LoElts = Lo.getValueType().getVectorMinNumElements(); uint64_t IdxVal = cast(Idx)->getZExtValue(); if (IdxVal < LoElts) { - assert(IdxVal + SubVT.getVectorNumElements() <= LoElts && + assert(IdxVal + SubVT.getVectorMinNumElements() <= LoElts && "Extracted subvector crosses vector split!"); return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Lo, Idx); } else { @@ -2733,7 +2746,10 @@ case ISD::EXTRACT_SUBVECTOR: Res = WidenVecRes_EXTRACT_SUBVECTOR(N); break; case ISD::INSERT_VECTOR_ELT: Res = WidenVecRes_INSERT_VECTOR_ELT(N); break; case ISD::LOAD: Res = WidenVecRes_LOAD(N); break; - case ISD::SCALAR_TO_VECTOR: Res = WidenVecRes_SCALAR_TO_VECTOR(N); break; + case ISD::SPLAT_VECTOR: + case ISD::SCALAR_TO_VECTOR: + Res = WidenVecRes_ScalarOp(N); + break; case ISD::SIGN_EXTEND_INREG: Res = WidenVecRes_InregOp(N); break; case ISD::VSELECT: case ISD::SELECT: Res 
= WidenVecRes_SELECT(N); break; @@ -3815,10 +3831,9 @@ return Res; } -SDValue DAGTypeLegalizer::WidenVecRes_SCALAR_TO_VECTOR(SDNode *N) { +SDValue DAGTypeLegalizer::WidenVecRes_ScalarOp(SDNode *N) { EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); - return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(N), - WidenVT, N->getOperand(0)); + return DAG.getNode(N->getOpcode(), SDLoc(N), WidenVT, N->getOperand(0)); } // Return true is this is a SETCC node or a strict version of it. diff --git a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll --- a/llvm/test/CodeGen/AArch64/sve-vector-splat.ll +++ b/llvm/test/CodeGen/AArch64/sve-vector-splat.ll @@ -134,6 +134,19 @@ ret %splat } +;; Widen/split splats of wide vector types. + +define @sve_splat_12xi32(i32 %val) { +; CHECK-LABEL: @sve_splat_12xi32 +; CHECK: mov z0.s, w0 +; CHECK-NEXT: mov z1.d, z0.d +; CHECK-NEXT: mov z2.d, z0.d +; CHECK-NEXT: ret + %ins = insertelement undef, i32 %val, i32 0 + %splat = shufflevector %ins, undef, zeroinitializer + ret %splat +} + define @sve_splat_2xi1(i1 %val) { ; CHECK-LABEL: @sve_splat_2xi1 ; CHECK: sbfx x8, x0, #0, #1