diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -314,6 +314,7 @@
   DUP_MERGE_PASSTHRU,
   INDEX_VECTOR,
 
+  // Cast between vectors of the same element type that differ in length.
   REINTERPRET_CAST,
 
   LD1_MERGE_ZERO,
@@ -1022,6 +1023,17 @@
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
   bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
+
+  // With the exception of data-predicate transitions, no instructions are
+  // required to cast between legal scalable vector types. However:
+  //  1. Packed and unpacked types have different bit lengths, meaning BITCAST
+  //     is not universally usable.
+  //  2. Most unpacked integer types are not legal and thus integer extends
+  //     cannot be used to convert between unpacked and packed types.
+  // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
+  // to transition between unpacked and packed types of the same element type,
+  // with BITCAST used otherwise.
+  SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 };
 
 namespace AArch64 {
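To make point 1 of the comment concrete: unpacked and packed types have different minimum bit widths, so a single BITCAST (which requires equal widths) cannot connect, say, nxv2f16 and nxv2i64. A minimal standalone sketch, not part of the patch or of LLVM, that works through the arithmetic:

```cpp
// Standalone sketch (hypothetical, not LLVM code): minimum bit widths of
// scalable vector types, assuming the 128-bit SVE granule the backend uses.
#include <cstdio>

// Minimum size in bits of an nxv<Count><Elt> type; the real size is this
// value scaled by the runtime vscale.
static unsigned minBits(unsigned EltBits, unsigned EltCount) {
  return EltBits * EltCount;
}

int main() {
  std::printf("nxv2f16 min bits: %u\n", minBits(16, 2)); // 32  (unpacked)
  std::printf("nxv8f16 min bits: %u\n", minBits(16, 8)); // 128 (packed)
  std::printf("nxv2i64 min bits: %u\n", minBits(64, 2)); // 128 (packed)
  // A direct BITCAST nxv2f16 -> nxv2i64 would be ill-typed (32 != 128 bits);
  // the cast must first REINTERPRET_CAST nxv2f16 to nxv8f16, then BITCAST.
  return 0;
}
```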
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -144,6 +144,25 @@
     return MVT::nxv4f32;
   case MVT::f64:
     return MVT::nxv2f64;
+  case MVT::bf16:
+    return MVT::nxv8bf16;
+  }
+}
+
+// NOTE: Currently there's only a need to return integer vector types. If this
+// changes then just add an extra "type" parameter.
+static inline EVT getPackedSVEVectorVT(ElementCount EC) {
+  switch (EC.getKnownMinValue()) {
+  default:
+    llvm_unreachable("unexpected element count for vector");
+  case 16:
+    return MVT::nxv16i8;
+  case 8:
+    return MVT::nxv8i16;
+  case 4:
+    return MVT::nxv4i32;
+  case 2:
+    return MVT::nxv2i64;
   }
 }
@@ -3988,14 +4007,10 @@
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by using an integer gather and casting the result.
   if (VT.isFloatingPoint()) {
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), PassThru);
-
+    EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
@@ -4015,7 +4030,7 @@
   SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
 
   if (VT.isFloatingPoint()) {
-    SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
+    SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
     return DAG.getMergeValues({Cast, Gather}, DL);
   }
@@ -4052,15 +4067,10 @@
      !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
    return SDValue();
 
-  // Handle FP data
+  // Handle FP data by casting the data so an integer scatter can be used.
   if (VT.isFloatingPoint()) {
-    VT = VT.changeVectorElementTypeToInteger();
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
-
+    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
@@ -17157,3 +17167,40 @@
   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
 }
+
+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable vector types!");
+  assert((VT.getVectorElementType() == MVT::i1) ==
+             (InVT.getVectorElementType() == MVT::i1) &&
+         "Cannot cast between data and predicate scalable vector types!");
+
+  if (InVT == VT)
+    return Op;
+
+  if (VT.getVectorElementType() == MVT::i1)
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+  EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+  assert((VT == PackedVT || InVT == PackedInVT) &&
+         "Cannot cast between unpacked scalable vector types!");
+
+  // Pack input if required.
+  if (InVT != PackedInVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+  // Unpack result if required.
+  if (VT != PackedVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  return Op;
+}
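The pack/bitcast/unpack phases of getSVESafeBitCast for data vectors can be traced with this small standalone model (hypothetical names, not LLVM code; it assumes the backend's 128-bit minimum register size):

```cpp
// Standalone sketch (hypothetical, not LLVM code): trace the node sequence
// getSVESafeBitCast emits when casting between two data-vector types.
#include <cstdio>
#include <string>

struct SVEType {
  std::string Name;  // e.g. "nxv2f16"
  char EltKind;      // 'i' or 'f' (bf16 folded into 'f' for brevity)
  unsigned EltBits;  // element width in bits
};

// Packed form of the same element type: enough elements to fill one 128-bit
// granule, e.g. f16 -> nxv8f16 (mirrors getPackedSVEVectorVT(EVT)).
static std::string packedName(char Kind, unsigned Bits) {
  return "nxv" + std::to_string(128 / Bits) + Kind + std::to_string(Bits);
}

// Mirrors the data-vector path above: pack the input if it is unpacked,
// BITCAST between the (equal-width) packed forms, then unpack the result.
static void traceCast(const SVEType &In, const SVEType &Out) {
  std::string PackedIn = packedName(In.EltKind, In.EltBits);
  std::string PackedOut = packedName(Out.EltKind, Out.EltBits);
  std::string Val = In.Name;
  if (Val != PackedIn) { // pack input if required
    std::printf("REINTERPRET_CAST %-8s -> %s\n", Val.c_str(), PackedIn.c_str());
    Val = PackedIn;
  }
  std::printf("BITCAST          %-8s -> %s\n", Val.c_str(), PackedOut.c_str());
  Val = PackedOut;
  if (Out.Name != Val) // unpack result if required
    std::printf("REINTERPRET_CAST %-8s -> %s\n", Val.c_str(), Out.Name.c_str());
}

int main() {
  traceCast({"nxv2f16", 'f', 16}, {"nxv2i64", 'i', 64});
  return 0;
}
```

For nxv2f16 -> nxv2i64 this prints a single REINTERPRET_CAST to nxv8f16 followed by a BITCAST; no unpack phase is needed because the destination type is already packed.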
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1721,6 +1721,7 @@
     def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   }
 
+  // These allow casting from/to unpacked predicate types.
   def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
@@ -1735,23 +1736,17 @@
   def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
 
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-
-  def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f64 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f32 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2bf16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4bf16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  // These allow casting from/to unpacked floating-point types.
+  def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
 
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)), (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;
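All of these patterns select to COPY_TO_REGCLASS, so no instruction is emitted: packed and unpacked views occupy the same Z (or P) register and differ only in where the live elements sit. A standalone sketch (hypothetical, not LLVM or patch code) of the byte layout within one 128-bit granule shows why the cast is free:

```cpp
// Standalone sketch (hypothetical, not LLVM code): element byte offsets for
// packed vs. unpacked views of one 128-bit SVE granule. The bytes do not
// move between views; only the offsets considered live change, which is why
// reinterpret_cast lowers to a plain register copy.
#include <cstdio>

// Print the starting byte offset of each element. Packed views use a stride
// equal to the element size; unpacked views use a wider (lane) stride, with
// each element assumed to sit in the low bits of its lane (little-endian).
static void printLayout(const char *Name, unsigned EltBytes,
                        unsigned StrideBytes, unsigned Count) {
  std::printf("%-8s:", Name);
  for (unsigned I = 0; I < Count; ++I)
    std::printf(" [byte %2u]", I * StrideBytes);
  std::printf("  (%u-byte elements)\n", EltBytes);
}

int main() {
  printLayout("nxv8f16", 2, 2, 8); // packed: contiguous f16 elements
  printLayout("nxv2f16", 2, 8, 2); // unpacked: one f16 per 64-bit lane
  return 0;
}
```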