diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -407,6 +407,9 @@
   bool tryReadRegister(SDNode *N);
   bool tryWriteRegister(SDNode *N);
 
+  bool trySelectCastFixedLengthToScalableVector(SDNode *N);
+  bool trySelectCastScalableToFixedLengthVector(SDNode *N);
+
   // Include the pieces autogenerated from the target description.
 #include "AArch64GenDAGISel.inc"
 
@@ -4163,61 +4166,64 @@
   ReplaceNode(N, N3);
 }
 
-// NOTE: We cannot use EXTRACT_SUBREG in all cases because the fixed length
-// vector types larger than NEON don't have a matching SubRegIndex.
-static SDNode *extractSubReg(SelectionDAG *DAG, EVT VT, SDValue V) {
-  assert(V.getValueType().isScalableVector() &&
-         V.getValueType().getSizeInBits().getKnownMinValue() ==
-             AArch64::SVEBitsPerBlock &&
-         "Expected to extract from a packed scalable vector!");
-  assert(VT.isFixedLengthVector() &&
-         "Expected to extract a fixed length vector!");
+bool AArch64DAGToDAGISel::trySelectCastFixedLengthToScalableVector(SDNode *N) {
+  assert(N->getOpcode() == ISD::INSERT_SUBVECTOR && "Invalid Node!");
 
-  SDLoc DL(V);
-  switch (VT.getSizeInBits()) {
-  case 64: {
-    auto SubReg = DAG->getTargetConstant(AArch64::dsub, DL, MVT::i32);
-    return DAG->getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, VT, V, SubReg);
-  }
-  case 128: {
-    auto SubReg = DAG->getTargetConstant(AArch64::zsub, DL, MVT::i32);
-    return DAG->getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, VT, V, SubReg);
-  }
-  default: {
-    auto RC = DAG->getTargetConstant(AArch64::ZPRRegClassID, DL, MVT::i64);
-    return DAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT, V, RC);
-  }
-  }
-}
+  // Bail when not a "cast" like insert_subvector.
+  if (cast<ConstantSDNode>(N->getOperand(2))->getZExtValue() != 0)
+    return false;
+  if (!N->getOperand(0).isUndef())
+    return false;
+
+  // Bail when normal isel should do the job.
+  EVT VT = N->getValueType(0);
+  EVT InVT = N->getOperand(1).getValueType();
+  if (VT.isFixedLengthVector() || InVT.isScalableVector())
+    return false;
+  if (InVT.getSizeInBits() <= 128)
+    return false;
+
+  // NOTE: We can only get here when doing fixed length SVE code generation.
+  // We do manual selection because the types involved are not linked to real
+  // registers (despite being legal) and must be coerced into SVE registers.
 
-// NOTE: We cannot use INSERT_SUBREG in all cases because the fixed length
-// vector types larger than NEON don't have a matching SubRegIndex.
-static SDNode *insertSubReg(SelectionDAG *DAG, EVT VT, SDValue V) {
-  assert(VT.isScalableVector() &&
-         VT.getSizeInBits().getKnownMinValue() == AArch64::SVEBitsPerBlock &&
+  assert(VT.getSizeInBits().getKnownMinValue() == AArch64::SVEBitsPerBlock &&
          "Expected to insert into a packed scalable vector!");
-  assert(V.getValueType().isFixedLengthVector() &&
-         "Expected to insert a fixed length vector!");
 
-  SDLoc DL(V);
-  switch (V.getValueType().getSizeInBits()) {
-  case 64: {
-    auto SubReg = DAG->getTargetConstant(AArch64::dsub, DL, MVT::i32);
-    auto Container = DAG->getMachineNode(TargetOpcode::IMPLICIT_DEF, DL, VT);
-    return DAG->getMachineNode(TargetOpcode::INSERT_SUBREG, DL, VT,
-                               SDValue(Container, 0), V, SubReg);
-  }
-  case 128: {
-    auto SubReg = DAG->getTargetConstant(AArch64::zsub, DL, MVT::i32);
-    auto Container = DAG->getMachineNode(TargetOpcode::IMPLICIT_DEF, DL, VT);
-    return DAG->getMachineNode(TargetOpcode::INSERT_SUBREG, DL, VT,
-                               SDValue(Container, 0), V, SubReg);
-  }
-  default: {
-    auto RC = DAG->getTargetConstant(AArch64::ZPRRegClassID, DL, MVT::i64);
-    return DAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT, V, RC);
-  }
-  }
+  SDLoc DL(N);
+  auto RC = CurDAG->getTargetConstant(AArch64::ZPRRegClassID, DL, MVT::i64);
+  ReplaceNode(N, CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT,
+                                        N->getOperand(1), RC));
+  return true;
+}
+
+bool AArch64DAGToDAGISel::trySelectCastScalableToFixedLengthVector(SDNode *N) {
+  assert(N->getOpcode() == ISD::EXTRACT_SUBVECTOR && "Invalid Node!");
+
+  // Bail when not a "cast" like extract_subvector.
+  if (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue() != 0)
+    return false;
+
+  // Bail when normal isel can do the job.
+  EVT VT = N->getValueType(0);
+  EVT InVT = N->getOperand(0).getValueType();
+  if (VT.isScalableVector() || InVT.isFixedLengthVector())
+    return false;
+  if (VT.getSizeInBits() <= 128)
+    return false;
+
+  // NOTE: We can only get here when doing fixed length SVE code generation.
+  // We do manual selection because the types involved are not linked to real
+  // registers (despite being legal) and must be coerced into SVE registers.
+
+  assert(InVT.getSizeInBits().getKnownMinValue() == AArch64::SVEBitsPerBlock &&
+         "Expected to extract from a packed scalable vector!");
+
+  SDLoc DL(N);
+  auto RC = CurDAG->getTargetConstant(AArch64::ZPRRegClassID, DL, MVT::i64);
+  ReplaceNode(N, CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, DL, VT,
+                                        N->getOperand(0), RC));
+  return true;
 }
 
 void AArch64DAGToDAGISel::Select(SDNode *Node) {
@@ -4296,49 +4302,15 @@
     break;
 
   case ISD::EXTRACT_SUBVECTOR: {
-    // Bail when not a "cast" like extract_subvector.
-    if (cast<ConstantSDNode>(Node->getOperand(1))->getZExtValue() != 0)
-      break;
-
-    // Bail when normal isel can do the job.
-    EVT InVT = Node->getOperand(0).getValueType();
-    if (VT.isScalableVector() || InVT.isFixedLengthVector())
-      break;
-
-    // NOTE: We can only get here when doing fixed length SVE code generation.
-    // We do manual selection because the types involved are not linked to real
-    // registers (despite being legal) and must be coerced into SVE registers.
-    //
-    // NOTE: If the above changes, be aware that selection will still not work
-    // because the td definition of extract_vector does not support extracting
-    // a fixed length vector from a scalable vector.
-
-    ReplaceNode(Node, extractSubReg(CurDAG, VT, Node->getOperand(0)));
-    return;
+    if (trySelectCastScalableToFixedLengthVector(Node))
+      return;
+    break;
   }
 
   case ISD::INSERT_SUBVECTOR: {
-    // Bail when not a "cast" like insert_subvector.
-    if (cast<ConstantSDNode>(Node->getOperand(2))->getZExtValue() != 0)
-      break;
-    if (!Node->getOperand(0).isUndef())
-      break;
-
-    // Bail when normal isel should do the job.
-    EVT InVT = Node->getOperand(1).getValueType();
-    if (VT.isFixedLengthVector() || InVT.isScalableVector())
-      break;
-
-    // NOTE: We can only get here when doing fixed length SVE code generation.
-    // We do manual selection because the types involved are not linked to real
-    // registers (despite being legal) and must be coerced into SVE registers.
-    //
-    // NOTE: If the above changes, be aware that selection will still not work
-    // because the td definition of insert_vector does not support inserting a
-    // fixed length vector into a scalable vector.
-
-    ReplaceNode(Node, insertSubReg(CurDAG, VT, Node->getOperand(1)));
-    return;
+    if (trySelectCastFixedLengthToScalableVector(Node))
+      return;
+    break;
   }
 
   case ISD::Constant: {
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1830,6 +1830,22 @@
   def : Pat<(nxv2bf16 (extract_subvector (nxv8bf16 ZPR:$Zs), (i64 6))),
             (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;
 
+  // extract/insert 64-bit fixed length vector from/into a scalable vector
+  foreach VT = [v8i8, v4i16, v2i32, v1i64, v4f16, v2f32, v1f64, v4bf16] in {
+    def : Pat<(VT (vector_extract_subvec (SVEContainerVT<VT>.Value ZPR:$Zs), (i64 0))),
+              (EXTRACT_SUBREG ZPR:$Zs, dsub)>;
+    def : Pat<(SVEContainerVT<VT>.Value (vector_insert_subvec undef, (VT V64:$src), (i64 0))),
+              (INSERT_SUBREG (IMPLICIT_DEF), $src, dsub)>;
+  }
+
+  // extract/insert 128-bit fixed length vector from/into a scalable vector
+  foreach VT = [v16i8, v8i16, v4i32, v2i64, v8f16, v4f32, v2f64, v8bf16] in {
+    def : Pat<(VT (vector_extract_subvec (SVEContainerVT<VT>.Value ZPR:$Zs), (i64 0))),
+              (EXTRACT_SUBREG ZPR:$Zs, zsub)>;
+    def : Pat<(SVEContainerVT<VT>.Value (vector_insert_subvec undef, (VT V128:$src), (i64 0))),
+              (INSERT_SUBREG (IMPLICIT_DEF), $src, zsub)>;
+  }
+
   // Concatenate two predicates.
   def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)),
             (UZP1_PPP_D $p1, $p2)>;
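For readers less familiar with TableGen foreach, here is roughly what a single iteration of the 128-bit loop above expands to. This is a sketch, not part of the patch, taking VT = v2i64, for which SVEContainerVT<v2i64>.Value (defined in SVEInstrFormats.td below) resolves to nxv2i64:

// Sketch: one expansion of the 128-bit foreach above, for VT = v2i64.
def : Pat<(v2i64 (vector_extract_subvec (nxv2i64 ZPR:$Zs), (i64 0))),
          (EXTRACT_SUBREG ZPR:$Zs, zsub)>;
def : Pat<(nxv2i64 (vector_insert_subvec undef, (v2i64 V128:$src), (i64 0))),
          (INSERT_SUBREG (IMPLICIT_DEF), $src, zsub)>;

These patterns cover exactly the 64-bit (dsub) and 128-bit (zsub) cases, the only fixed length sizes with a real SubRegIndex; wider fixed length vectors fall through to the trySelectCast* hooks above and are coerced with COPY_TO_REGCLASS instead.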
+class SVEContainerVT<ValueType VT> {
+  ValueType Value = !cond(
+    // fixed length vectors
+    !eq(VT, v8i8): nxv16i8,
+    !eq(VT, v16i8): nxv16i8,
+    !eq(VT, v4i16): nxv8i16,
+    !eq(VT, v8i16): nxv8i16,
+    !eq(VT, v2i32): nxv4i32,
+    !eq(VT, v4i32): nxv4i32,
+    !eq(VT, v1i64): nxv2i64,
+    !eq(VT, v2i64): nxv2i64,
+    !eq(VT, v4f16): nxv8f16,
+    !eq(VT, v8f16): nxv8f16,
+    !eq(VT, v2f32): nxv4f32,
+    !eq(VT, v4f32): nxv4f32,
+    !eq(VT, v1f64): nxv2f64,
+    !eq(VT, v2f64): nxv2f64,
+    !eq(VT, v4bf16): nxv8bf16,
+    !eq(VT, v8bf16): nxv8bf16,
+    // unpacked scalable vectors
+    !eq(VT, nxv2f16): nxv8f16,
+    !eq(VT, nxv4f16): nxv8f16,
+    !eq(VT, nxv2f32): nxv4f32,
+    !eq(VT, nxv2bf16): nxv8bf16,
+    !eq(VT, nxv4bf16): nxv8bf16,
+    true : VT);
+}
+
 def SDT_AArch64Setcc : SDTypeProfile<1, 4, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
   SDTCVecEltisVT<0, i1>, SDTCVecEltisVT<1, i1>, SDTCisSameAs<2, 3>,
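A quick usage sketch of the helper; the defvar names below are illustrative only and not part of the patch:

// Packed container for an unpacked scalable type.
defvar PackedF16  = SVEContainerVT<nxv2f16>.Value;  // nxv8f16
// Packed container for a fixed length type.
defvar Container8 = SVEContainerVT<v8i8>.Value;     // nxv16i8
// No table entry: falls through to the input type.
defvar Same       = SVEContainerVT<nxv4i32>.Value;  // nxv4i32

@@ -2804,16 +2834,10 @@
   def NAME : sve_fp_2op_p_zd<opc, asm, i_zprtype, o_zprtype, Sz>,
              SVEPseudo2Instr<NAME, 1>;
   // convert vt1 to a packed type for the intrinsic patterns
-  defvar packedvt1 = !cond(!eq(!cast<string>(vt1), "nxv2f16"): nxv8f16,
-                           !eq(!cast<string>(vt1), "nxv4f16"): nxv8f16,
-                           !eq(!cast<string>(vt1), "nxv2f32"): nxv4f32,
-                           1 : vt1);
+  defvar packedvt1 = SVEContainerVT<vt1>.Value;
 
   // convert vt3 to a packed type for the intrinsic patterns
-  defvar packedvt3 = !cond(!eq(!cast<string>(vt3), "nxv2f16"): nxv8f16,
-                           !eq(!cast<string>(vt3), "nxv4f16"): nxv8f16,
-                           !eq(!cast<string>(vt3), "nxv2f32"): nxv4f32,
-                           1 : vt3);
+  defvar packedvt3 = SVEContainerVT<vt3>.Value;
 
   def : SVE_3_Op_Pat<packedvt1, int_op, packedvt1, vt2, packedvt3, !cast<Instruction>(NAME)>;
   def : SVE_1_Op_Passthru_Pat<vt1, ir_op, vt2, vt3, !cast<Instruction>(NAME)>;
@@ -2833,10 +2857,7 @@
              SVEPseudo2Instr<NAME, 1>;
 
   // convert vt1 to a packed type for the intrinsic patterns
-  defvar packedvt1 = !cond(!eq(!cast<string>(vt1), "nxv2f16"): nxv8f16,
-                           !eq(!cast<string>(vt1), "nxv4f16"): nxv8f16,
-                           !eq(!cast<string>(vt1), "nxv2f32"): nxv4f32,
-                           1 : vt1);
+  defvar packedvt1 = SVEContainerVT<vt1>.Value;
 
   def : SVE_3_Op_Pat<packedvt1, int_op, packedvt1, vt2, vt3, !cast<Instruction>(NAME)>;
   def : SVE_1_Op_Passthru_Round_Pat<vt1, ir_op, vt2, vt3, !cast<Instruction>(NAME)>;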
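For the nxv2f16/nxv4f16/nxv2f32 cases the old inline !cond handled, the helper resolves the same way (it also gains bf16 entries the old tables lacked), so the defvar rewrites above keep the existing behaviour. A before/after sketch for one instantiation, vt1 = nxv4f16, with illustrative defvar names:

// Before: inline table, falling through to vt1 when already packed.
defvar OldPacked = !cond(!eq(!cast<string>(nxv4f16), "nxv2f16"): nxv8f16,
                         !eq(!cast<string>(nxv4f16), "nxv4f16"): nxv8f16,
                         !eq(!cast<string>(nxv4f16), "nxv2f32"): nxv4f32,
                         1 : nxv4f16);            // nxv8f16
// After: shared helper, same result.
defvar NewPacked = SVEContainerVT<nxv4f16>.Value;  // nxv8f16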