diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -1066,6 +1066,8 @@ void ReplaceNodeResults(SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const override; + void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl &Results, + SelectionDAG &DAG) const; void ReplaceExtractSubVectorResults(SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1192,6 +1192,10 @@ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); } + // Legalize unpacked bitcasts to REINTERPRET_CAST. + for (auto VT : {MVT::nxv2i32, MVT::nxv2f32}) + setOperationAction(ISD::BITCAST, VT, Custom); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); @@ -3508,17 +3512,30 @@ return CallResult.first; } +static MVT getSVEContainerType(EVT ContentTy); + SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { EVT OpVT = Op.getValueType(); + EVT ArgVT = Op.getOperand(0).getValueType(); if (useSVEForFixedLengthVectorVT(OpVT)) return LowerFixedLengthBitcastToSVE(Op, DAG); + if (OpVT == MVT::nxv2f32) { + if (ArgVT.isInteger()) { + SDValue ExtResult = + DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT), + Op.getOperand(0)); + return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG); + } + return getSVESafeBitCast(MVT::nxv2f32, Op.getOperand(0), DAG); + } + if (OpVT != MVT::f16 && OpVT != MVT::bf16) return SDValue(); - assert(Op.getOperand(0).getValueType() == MVT::i16); + assert(ArgVT == MVT::i16); SDLoc DL(Op); Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0)); @@ -16866,11 +16883,18 @@ return true; } -static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl &Results, - SelectionDAG &DAG) { +void AArch64TargetLowering::ReplaceBITCASTResults( + SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const { SDLoc DL(N); SDValue Op = N->getOperand(0); + if (N->getValueType(0) == MVT::nxv2i32 && + Op.getValueType().isFloatingPoint()) { + SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, Op, DAG); + Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::nxv2i32, CastResult)); + return; + } + if (N->getValueType(0) != MVT::i16 || (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16)) return; @@ -18428,8 +18452,6 @@ EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType()); EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType()); - assert((VT == PackedVT || InVT == PackedInVT) && - "Cannot cast between unpacked scalable vector types!"); // Pack input if required. if (InVT != PackedInVT) diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll --- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll +++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll @@ -450,5 +450,39 @@ ret %bc } +define @bitcast_short_float_to_i32( %v) #0 { +; CHECK-LABEL: bitcast_short_float_to_i32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fcvt z0.s, p0/m, z0.d +; CHECK-NEXT: ret + %trunc = fptrunc %v to + %bitcast = bitcast %trunc to + ret %bitcast +} + +define @bitcast_short_i32_to_float( %v) #0 { +; CHECK-LABEL: bitcast_short_i32_to_float: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fcvt z0.d, p0/m, z0.s +; CHECK-NEXT: ret + %trunc = trunc %v to + %bitcast = bitcast %trunc to + %extended = fpext %bitcast to + ret %extended +} + +define @bitcast_short_half_to_float( %v) #0 { +; CHECK-LABEL: bitcast_short_half_to_float: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z0.h +; CHECK-NEXT: ret + %add = fadd %v, %v + %bitcast = bitcast %add to + ret %bitcast +} + ; +bf16 is required for the bfloat version. attributes #0 = { "target-features"="+sve,+bf16" }