Index: llvm/include/llvm/IR/IntrinsicsAArch64.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1634,6 +1634,18 @@
 def int_aarch64_sve_ptest_last : AdvSIMD_SVE_PTEST_Intrinsic;
 
 //
+// Reinterpreting data
+//
+
+def int_aarch64_sve_convert_from_svbool : Intrinsic<[llvm_anyvector_ty],
+                                                    [llvm_nxv16i1_ty],
+                                                    [IntrNoMem]>;
+
+def int_aarch64_sve_convert_to_svbool : Intrinsic<[llvm_nxv16i1_ty],
+                                                  [llvm_anyvector_ty],
+                                                  [IntrNoMem]>;
+
+//
 // Gather loads: scalar base + vector offsets
 //
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -215,6 +215,8 @@
   PTEST,
   PTRUE,
 
+  REINTERPRET_CAST,
+
   LDNF1,
   LDNF1S,
   LDFF1,
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1339,6 +1339,7 @@
   case AArch64ISD::LASTA:             return "AArch64ISD::LASTA";
   case AArch64ISD::LASTB:             return "AArch64ISD::LASTB";
   case AArch64ISD::REV:               return "AArch64ISD::REV";
+  case AArch64ISD::REINTERPRET_CAST:  return "AArch64ISD::REINTERPRET_CAST";
   case AArch64ISD::TBL:               return "AArch64ISD::TBL";
   case AArch64ISD::NOT:               return "AArch64ISD::NOT";
   case AArch64ISD::BIT:               return "AArch64ISD::BIT";
@@ -2950,6 +2951,12 @@
                        DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1));
 }
 
+static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
+                               int Pattern) {
+  return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
+                     DAG.getTargetConstant(Pattern, DL, MVT::i32));
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                                                        SelectionDAG &DAG) const {
   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -3037,6 +3044,21 @@
   case Intrinsic::aarch64_sve_ptrue:
     return DAG.getNode(AArch64ISD::PTRUE, dl, Op.getValueType(),
                        Op.getOperand(1));
+  case Intrinsic::aarch64_sve_convert_from_svbool:
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
+                       Op.getOperand(1));
+  case Intrinsic::aarch64_sve_convert_to_svbool: {
+    EVT VT = Op.getValueType();
+    SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, VT,
+                                      Op.getOperand(1));
+    // Return the reinterpret if the cast isn't changing type,
+    // i.e. <vscale x 16 x i1> -> <vscale x 16 x i1>.
+    if (VT == Op.getOperand(1).getValueType())
+      return Reinterpret;
+    // Otherwise, zero the newly introduced lanes.
+    SDValue Mask = getPTrue(DAG, dl, VT, AArch64SVEPredPattern::all);
+    return DAG.getNode(ISD::AND, dl, VT, Reinterpret, Mask);
+  }
 
   case Intrinsic::aarch64_sve_insr: {
     SDValue Scalar = Op.getOperand(2);
@@ -7451,9 +7473,12 @@
     SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
     return DAG.getNode(AArch64ISD::DUP, dl, VT, SplatVal);
   case MVT::i1: {
+    if (auto CSplatVal = dyn_cast<ConstantSDNode>(SplatVal))
+      if (CSplatVal->isNullValue())
+        return SDValue(DAG.getMachineNode(AArch64::PFALSE, dl, VT), 0);
     // The general case of i1. There isn't any natural way to do this,
     // so we use some trickery with whilelo.
-    // TODO: Add special cases for splat of constant true/false.
+    // TODO: Add special case for splat of constant true.
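+    // (Editor's note, not part of the original patch: after the
+    // SIGN_EXTEND_INREG below, SplatVal is 0 for a false splat and -1 for a
+    // true one, so the whilelo trick mentioned above presumably yields an
+    // all-false or all-true predicate respectively.)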
     SplatVal = DAG.getAnyExtOrTrunc(SplatVal, dl, MVT::i64);
     SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, MVT::i64, SplatVal,
                            DAG.getValueType(MVT::i1));
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -96,6 +96,8 @@
 def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
 def AArch64ptest     : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
 
+def reinterpret_cast : SDNode<"AArch64ISD::REINTERPRET_CAST", SDTUnaryOp>;
+
 let Predicates = [HasSVE] in {
 
   defm RDFFR_PPz : sve_int_rdffr_pred<0b0, "rdffr", int_aarch64_sve_rdffr_z>;
@@ -1201,6 +1203,29 @@
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))),  (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+  def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))),   (COPY_TO_REGCLASS PPR:$src, PPR)>;
+
+  def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
+            (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;
+  def : Pat<(nxv8i1 (and PPR:$Ps1, PPR:$Ps2)),
+            (AND_PPzPP (PTRUE_H 31), PPR:$Ps1, PPR:$Ps2)>;
+  def : Pat<(nxv4i1 (and PPR:$Ps1, PPR:$Ps2)),
+            (AND_PPzPP (PTRUE_S 31), PPR:$Ps1, PPR:$Ps2)>;
+  def : Pat<(nxv2i1 (and PPR:$Ps1, PPR:$Ps2)),
+            (AND_PPzPP (PTRUE_D 31), PPR:$Ps1, PPR:$Ps2)>;
+
   // Add more complex addressing modes here as required
   multiclass pred_load<...> {
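[Editor's note, not part of the patch: the four `and` patterns above exist because the new convert_to_svbool lowering emits a plain ISD::AND on predicate types, which otherwise has no selection pattern. Predicate pattern 31 is the SVE "ALL" pattern, so each match selects to an AND_PPzPP governed by an all-active predicate of the matching element width. As a sketch of IR these patterns make selectable (the function and value names here are mine):

define <vscale x 16 x i1> @pred_and(<vscale x 16 x i1> %a, <vscale x 16 x i1> %b) {
  ; Expected to select roughly to: ptrue pX.b ; and p0.b, pX/z, p0.b, p1.b
  %res = and <vscale x 16 x i1> %a, %b
  ret <vscale x 16 x i1> %res
}]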
Index: llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
@@ -0,0 +1,89 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Converting to svbool_t (<vscale x 16 x i1>)
+;
+
+define <vscale x 16 x i1> @reinterpret_bool_b2b() {
+; CHECK-LABEL: reinterpret_bool_b2b:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1> zeroinitializer)
+  ret <vscale x 16 x i1> %out
+}
+
+define <vscale x 16 x i1> @reinterpret_bool_h2b() {
+; CHECK-LABEL: reinterpret_bool_h2b:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: and p0.b, p1/z, p0.b, p1.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> zeroinitializer)
+  ret <vscale x 16 x i1> %out
+}
+
+define <vscale x 16 x i1> @reinterpret_bool_s2b() {
+; CHECK-LABEL: reinterpret_bool_s2b:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: and p0.b, p1/z, p0.b, p1.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> zeroinitializer)
+  ret <vscale x 16 x i1> %out
+}
+
+define <vscale x 16 x i1> @reinterpret_bool_d2b() {
+; CHECK-LABEL: reinterpret_bool_d2b:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ptrue p1.b
+; CHECK-NEXT: and p0.b, p1/z, p0.b, p1.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> zeroinitializer)
+  ret <vscale x 16 x i1> %out
+}
+
+;
+; Converting from svbool_t
+;
+
+define <vscale x 16 x i1> @reinterpret_bool_b2b_from() {
+; CHECK-LABEL: reinterpret_bool_b2b_from:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1> zeroinitializer)
+  ret <vscale x 16 x i1> %out
+}
+
+define <vscale x 8 x i1> @reinterpret_bool_b2h() {
+; CHECK-LABEL: reinterpret_bool_b2h:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> zeroinitializer)
+  ret <vscale x 8 x i1> %out
+}
+
+define <vscale x 4 x i1> @reinterpret_bool_b2s() {
+; CHECK-LABEL: reinterpret_bool_b2s:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> zeroinitializer)
+  ret <vscale x 4 x i1> %out
+}
+
+define <vscale x 2 x i1> @reinterpret_bool_b2d() {
+; CHECK-LABEL: reinterpret_bool_b2d:
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+  %out = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> zeroinitializer)
+  ret <vscale x 2 x i1> %out
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1>)
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
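[Editor's note, not part of the patch: a minimal usage sketch of the two intrinsics on a non-constant predicate, complementing the zeroinitializer tests above. The function name is mine; the intrinsic names, types, and lowering behaviour come from the patch:

define <vscale x 4 x i1> @round_trip(<vscale x 4 x i1> %pg) {
  ; Widen to the generic svbool_t type (<vscale x 16 x i1>). Per the lowering
  ; above, this reinterprets the predicate and then ANDs it with a PTRUE mask
  ; to zero the newly introduced lanes.
  %b = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %pg)
  ; Narrow back to <vscale x 4 x i1>; this lowers to a plain REINTERPRET_CAST.
  %r = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %b)
  ret <vscale x 4 x i1> %r
}

declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)]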