diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1069,6 +1069,8 @@
 
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
+  void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
+                             SelectionDAG &DAG) const;
   void ReplaceExtractSubVectorResults(SDNode *N,
                                       SmallVectorImpl<SDValue> &Results,
                                       SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1179,6 +1179,10 @@
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
     }
 
+    // Custom lower bitcasts to/from the unpacked nxv2i32 and nxv2f32 types.
+    for (auto VT : {MVT::nxv2i32, MVT::nxv2f32})
+      setOperationAction(ISD::BITCAST, VT, Custom);
+
     for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
@@ -3486,6 +3490,8 @@
   return CallResult.first;
 }
 
+static MVT getSVEContainerType(EVT ContentTy);
+
 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
                                             SelectionDAG &DAG) const {
   EVT OpVT = Op.getValueType();
@@ -3493,6 +3499,27 @@
   if (useSVEForFixedLengthVectorVT(OpVT))
     return LowerFixedLengthBitcastToSVE(Op, DAG);
 
+  if (OpVT == MVT::nxv2f32) {
+    SDValue Arg = Op.getOperand(0);
+    EVT ArgVT = Arg.getValueType();
+
+    // Bitcasting between unpacked vector types of different element counts is
+    // not a NOP because the live elements are laid out differently:
+    //                01234567
+    // e.g. nxv2f32 = XX??XX??
+    //      nxv4f16 = X?X?X?X?
+    // Bail out so such casts are not lowered as a plain register reinterpret.
+    if (ArgVT.getVectorElementCount() != OpVT.getVectorElementCount())
+      return SDValue();
+
+    // An nxv2i32 operand is unpacked; any-extend it to its packed container
+    // type so the bits can be safely reinterpreted as nxv2f32.
+    if (ArgVT.isInteger())
+      Arg = DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
+                        Arg);
+    return getSVESafeBitCast(OpVT, Arg, DAG);
+  }
+
   if (OpVT != MVT::f16 && OpVT != MVT::bf16)
     return SDValue();
 
@@ -16752,11 +16779,31 @@
   return true;
 }
 
-static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
-                                  SelectionDAG &DAG) {
+void AArch64TargetLowering::ReplaceBITCASTResults(
+    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   SDLoc DL(N);
   SDValue Op = N->getOperand(0);
 
+  EVT VT = N->getValueType(0);
+  EVT SrcVT = Op.getValueType();
+
+  if (VT == MVT::nxv2i32 && SrcVT.isFloatingPoint()) {
+    // Bitcasting between unpacked vector types of different element counts is
+    // not a NOP because the live elements are laid out differently:
+    //                01234567
+    // e.g. nxv2i32 = XX??XX??
+    //      nxv4f16 = X?X?X?X?
+    // Bail out so that such casts fall back to the default legalization.
+    if (VT.getVectorElementCount() != SrcVT.getVectorElementCount())
+      return;
+
+    // Reinterpret the operand as the packed integer container type, then
+    // truncate to the unpacked nxv2i32 result.
+    SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
+    return;
+  }
+
   if (N->getValueType(0) != MVT::i16 ||
       (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16))
     return;
diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
--- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
@@ -450,5 +450,39 @@
   ret <vscale x 2 x double> %bc
 }
 
+define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double> %v) #0 {
+; CHECK-LABEL: bitcast_short_float_to_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcvt z0.s, p0/m, z0.d
+; CHECK-NEXT:    ret
+  %trunc = fptrunc <vscale x 2 x double> %v to <vscale x 2 x float>
+  %bitcast = bitcast <vscale x 2 x float> %trunc to <vscale x 2 x i32>
+  ret <vscale x 2 x i32> %bitcast
+}
+
+define <vscale x 2 x double> @bitcast_short_i32_to_float(<vscale x 2 x i64> %v) #0 {
+; CHECK-LABEL: bitcast_short_i32_to_float:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcvt z0.d, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %trunc = trunc <vscale x 2 x i64> %v to <vscale x 2 x i32>
+  %bitcast = bitcast <vscale x 2 x i32> %trunc to <vscale x 2 x float>
+  %extended = fpext <vscale x 2 x float> %bitcast to <vscale x 2 x double>
+  ret <vscale x 2 x double> %extended
+}
+
+; The element counts differ, so the live lanes of the unpacked operand and
+; result are laid out differently and this bitcast must not be lowered as a
+; plain reinterpret of the register.
+; TODO: regenerate the assertions with utils/update_llc_test_checks.py.
+define <vscale x 2 x float> @bitcast_short_half_to_float(<vscale x 4 x half> %v) #0 {
+; CHECK-LABEL: bitcast_short_half_to_float:
+  %add = fadd <vscale x 4 x half> %v, %v
+  %bitcast = bitcast <vscale x 4 x half> %add to <vscale x 2 x float>
+  ret <vscale x 2 x float> %bitcast
+}
+
 ; +bf16 is required for the bfloat version.
 attributes #0 = { "target-features"="+sve,+bf16" }