diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -1069,6 +1069,8 @@ void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const override; + void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results, + SelectionDAG &DAG) const; void ReplaceExtractSubVectorResults(SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1179,6 +1179,10 @@ setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); } + // Legalize unpacked bitcasts to REINTERPRET_CAST. + for (auto VT : {MVT::nxv2i32, MVT::nxv2f32}) + setOperationAction(ISD::BITCAST, VT, Custom); + for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); @@ -3486,6 +3490,8 @@ return CallResult.first; } +static MVT getSVEContainerType(EVT ContentTy); + SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const { EVT OpVT = Op.getValueType(); @@ -3493,6 +3499,20 @@ if (useSVEForFixedLengthVectorVT(OpVT)) return LowerFixedLengthBitcastToSVE(Op, DAG); + if (OpVT == MVT::nxv2f32) { + if (Op.getOperand(0).getValueType().isInteger()) { + SDValue ExtResult = + DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), + getSVEContainerType(Op.getOperand(0).getValueType()), + Op.getOperand(0)); + return getSVESafeBitCast(MVT::nxv2f32, ExtResult, DAG); + } else { + SDValue CastResult = + getSVESafeBitCast(MVT::nxv4f32, Op.getOperand(0), DAG); + return getSVESafeBitCast(MVT::nxv2f32, CastResult, DAG); + } + } + if (OpVT != MVT::f16 && OpVT != MVT::bf16) return SDValue(); @@ -16752,11 +16772,19 @@ return 
true; } -static void ReplaceBITCASTResults(SDNode *N, SmallVectorImpl<SDValue> &Results, - SelectionDAG &DAG) { +void AArch64TargetLowering::ReplaceBITCASTResults( + SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const { SDLoc DL(N); SDValue Op = N->getOperand(0); + if (N->getValueType(0) == MVT::nxv2i32 && + Op.getValueType().isFloatingPoint()) { + SDValue CastResult = getSVESafeBitCast(MVT::nxv2i64, N->getOperand(0), DAG); + Results.push_back( + DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::nxv2i32, CastResult)); + return; + } + if (N->getValueType(0) != MVT::i16 || (Op.getValueType() != MVT::f16 && Op.getValueType() != MVT::bf16)) return; diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll --- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll +++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll @@ -450,5 +450,46 @@ ret %bc } +define <vscale x 2 x i32> @bitcast_short_float_to_i32(<vscale x 2 x double>* %p) #0 { +; CHECK-LABEL: bitcast_short_float_to_i32: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: ld1d { z0.d }, p0/z, [x0] +; CHECK-NEXT: fcvt z0.s, p0/m, z0.d +; CHECK-NEXT: ret + %load = load <vscale x 2 x double>, <vscale x 2 x double>* %p + %trunc = fptrunc <vscale x 2 x double> %load to <vscale x 2 x float> + %bitcast = bitcast <vscale x 2 x float> %trunc to <vscale x 2 x i32> + ret <vscale x 2 x i32> %bitcast +} + +define <vscale x 2 x double> @bitcast_short_i32_to_float(<vscale x 2 x i64>* %p) #0 { +; CHECK-LABEL: bitcast_short_i32_to_float: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0] +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: fcvt z0.d, p0/m, z0.s +; CHECK-NEXT: ret + %load = load <vscale x 2 x i64>, <vscale x 2 x i64>* %p + %trunc = trunc <vscale x 2 x i64> %load to <vscale x 2 x i32> + %bitcast = bitcast <vscale x 2 x i32> %trunc to <vscale x 2 x float> + %extended = fpext <vscale x 2 x float> %bitcast to <vscale x 2 x double> + ret <vscale x 2 x double> %extended +} + +define <vscale x 2 x float> @bitcast_short_half_to_float(<vscale x 4 x half>* %p) #0 { +; CHECK-LABEL: bitcast_short_half_to_float: +; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.s +; CHECK-NEXT: ld1h { z0.s }, p0/z, [x0] +; CHECK-NEXT: fadd z0.h, p0/m, z0.h, z0.h +; CHECK-NEXT: ret + %load = load <vscale x 4 x half>, <vscale x 4 x half>* %p + %add = fadd <vscale x 4 x half> %load, %load + %bitcast = bitcast <vscale x 4 x half> %add to <vscale x 2 x float> + ret <vscale x 2 x float> %bitcast +} + ; +bf16 is required for the bfloat version. 
attributes #0 = { "target-features"="+sve,+bf16" }