Review notes (free text ahead of the first "diff --git" header; not part of
any hunk):
- lowerConvertToSVBool lowers aarch64_sve_convert_to_svbool. It reinterprets
  the predicate operand as the svbool result type and ANDs it with a
  reinterpreted all-true predicate of the input type to zero the lanes
  introduced by the widening. The AND is skipped when the operand is a
  SETCC_MERGE_ZERO node or an aarch64_sve_ptrue intrinsic call.
- The INTRINSIC_WO_CHAIN case is now wrapped in braces with an explicit
  break, and the switch has a default label, so labels added later cannot
  jump past the initialisation of Intr.
- NOTE(review): the angle-bracketed tokens (cast<ConstantSDNode>, the
  <vscale x N x ...> IR types) were missing from the copy under review and
  have been restored from the intrinsic name mangling. The leading context
  line of the test hunk and the whitespace of all context lines are inferred;
  regenerate the patch against the tree before applying.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3678,6 +3678,47 @@
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
+/// Lower aarch64_sve_convert_to_svbool: reinterpret the predicate operand as
+/// the svbool result type, zeroing the lanes introduced by the widening
+/// unless the operand is known to produce them as zero already.
+static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+  SDLoc dl(Op);
+  EVT OutVT = Op.getValueType();
+  EVT InVT = Op.getOperand(1).getValueType();
+
+  // Return the operand if the cast isn't changing type,
+  // i.e. <n x 16 x i1> -> <n x 16 x i1>
+  if (InVT == OutVT)
+    return Op.getOperand(1);
+
+  SDValue Reinterpret =
+      DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Op.getOperand(1));
+
+  // If the argument converted to an svbool is a ptrue or a comparison, the
+  // lanes introduced by the widening are zero by construction.
+  switch (Op.getOperand(1).getOpcode()) {
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return Reinterpret;
+  case ISD::INTRINSIC_WO_CHAIN: {
+    // Scope Intr to this case so later labels never jump past its
+    // initialisation.
+    unsigned Intr =
+        cast<ConstantSDNode>(Op.getOperand(1).getOperand(0))->getZExtValue();
+    if (Intr == Intrinsic::aarch64_sve_ptrue)
+      return Reinterpret;
+    break;
+  }
+  default:
+    break;
+  }
+
+  // Otherwise, zero the newly introduced lanes.
+  SDValue Mask = getPTrue(DAG, dl, InVT, AArch64SVEPredPattern::all);
+  SDValue MaskReinterpret =
+      DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Mask);
+  return DAG.getNode(ISD::AND, dl, OutVT, Reinterpret, MaskReinterpret);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                                                       SelectionDAG &DAG) const {
   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -3781,6 +3822,8 @@
   case Intrinsic::aarch64_sve_convert_from_svbool:
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
                        Op.getOperand(1));
+  case Intrinsic::aarch64_sve_convert_to_svbool:
+    return lowerConvertToSVBool(Op, DAG);
   case Intrinsic::aarch64_sve_fneg:
     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -3836,22 +3879,6 @@
   case Intrinsic::aarch64_sve_neg:
     return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
-  case Intrinsic::aarch64_sve_convert_to_svbool: {
-    EVT OutVT = Op.getValueType();
-    EVT InVT = Op.getOperand(1).getValueType();
-    // Return the operand if the cast isn't changing type,
-    // i.e. <n x 16 x i1> -> <n x 16 x i1>
-    if (InVT == OutVT)
-      return Op.getOperand(1);
-    // Otherwise, zero the newly introduced lanes.
-    SDValue Reinterpret =
-        DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Op.getOperand(1));
-    SDValue Mask = getPTrue(DAG, dl, InVT, AArch64SVEPredPattern::all);
-    SDValue MaskReinterpret =
-        DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Mask);
-    return DAG.getNode(ISD::AND, dl, OutVT, Reinterpret, MaskReinterpret);
-  }
-
   case Intrinsic::aarch64_sve_insr: {
     SDValue Scalar = Op.getOperand(2);
     EVT ScalarTy = Scalar.getValueType();
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
@@ -73,6 +73,31 @@
   ret <vscale x 2 x i1> %out
 }
 
+define <vscale x 16 x i1> @reinterpret_ptrue() {
+; Reinterpreting a ptrue should not introduce an `and` instruction.
+; CHECK-LABEL: reinterpret_ptrue:{{.*$}}
+; CHECK: ptrue
+; CHECK-NEXT: ret
+  %in = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %out = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %in)
+  ret <vscale x 16 x i1> %out
+}
+
+define <vscale x 16 x i1> @reinterpret_cmpgt(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; Reinterpreting a comparison should not introduce an `and` instruction.
+; CHECK-LABEL: reinterpret_cmpgt:{{.*$}}
+; CHECK: cmpgt
+; CHECK-NOT: and
+; CHECK-NEXT: ret
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.cmpgt.nxv8i16(<vscale x 8 x i1> %1, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+  %3 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %2)
+  ret <vscale x 16 x i1> %3
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 immarg)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.cmpgt.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)