diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -3690,6 +3690,44 @@
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
+/// Lower a call to llvm.aarch64.sve.convert.to.svbool: reinterpret the input
+/// predicate as an svbool and, unless the input is known to leave the lanes
+/// introduced by the widening clear, AND it with an all-true mask of InVT.
+static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+  SDLoc DL(Op);
+  EVT OutVT = Op.getValueType();
+  SDValue InOp = Op.getOperand(1);
+  EVT InVT = InOp.getValueType();
+
+  // Return the operand if the cast isn't changing type,
+  // i.e. <n x 16 x i1> -> <n x 16 x i1>
+  if (InVT == OutVT)
+    return InOp;
+
+  SDValue Reinterpret =
+      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+
+  // If the argument converted to an svbool is a ptrue (either the lowered
+  // AArch64ISD::PTRUE node or the aarch64_sve_ptrue intrinsic) or a
+  // comparison, the lanes introduced by the widening are zero by
+  // construction.
+  switch (InOp.getOpcode()) {
+  case AArch64ISD::PTRUE:
+  case AArch64ISD::SETCC_MERGE_ZERO:
+    return Reinterpret;
+  case ISD::INTRINSIC_WO_CHAIN:
+    if (InOp.getConstantOperandVal(0) == Intrinsic::aarch64_sve_ptrue)
+      return Reinterpret;
+    break;
+  }
+
+  // Otherwise, zero the newly introduced lanes.
+  SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
+  SDValue MaskReinterpret =
+      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
+  return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+}
+
 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                                                      SelectionDAG &DAG) const {
   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -3793,6 +3831,8 @@
   case Intrinsic::aarch64_sve_convert_from_svbool:
     return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
                        Op.getOperand(1));
+  case Intrinsic::aarch64_sve_convert_to_svbool:
+    return lowerConvertToSVBool(Op, DAG);
   case Intrinsic::aarch64_sve_fneg:
     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -3848,22 +3888,6 @@
   case Intrinsic::aarch64_sve_neg:
     return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
-  case Intrinsic::aarch64_sve_convert_to_svbool: {
-    EVT OutVT = Op.getValueType();
-    EVT InVT = Op.getOperand(1).getValueType();
-    // Return the operand if the cast isn't changing type,
-    // i.e. <n x 16 x i1> -> <n x 16 x i1>
-    if (InVT == OutVT)
-      return Op.getOperand(1);
-    // Otherwise, zero the newly introduced lanes.
-    SDValue Reinterpret =
-        DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Op.getOperand(1));
-    SDValue Mask = getPTrue(DAG, dl, InVT, AArch64SVEPredPattern::all);
-    SDValue MaskReinterpret =
-        DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, OutVT, Mask);
-    return DAG.getNode(ISD::AND, dl, OutVT, Reinterpret, MaskReinterpret);
-  }
-
   case Intrinsic::aarch64_sve_insr: {
     SDValue Scalar = Op.getOperand(2);
     EVT ScalarTy = Scalar.getValueType();
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-reinterpret.ll
@@ -1,3 +1,4 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
 ; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve < %s | FileCheck %s
 
 ;
@@ -6,37 +7,41 @@
 
 define <vscale x 16 x i1> @reinterpret_bool_from_b(<vscale x 16 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_from_b:
-; CHECK: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1> %pg)
   ret <vscale x 16 x i1> %out
 }
 
 define <vscale x 16 x i1> @reinterpret_bool_from_h(<vscale x 8 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_from_h:
-; CHECK: ptrue p1.h
-; CHECK-NEXT: ptrue p2.b
-; CHECK-NEXT: and p0.b, p2/z, p0.b, p1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    ptrue p2.b
+; CHECK-NEXT:    and p0.b, p2/z, p0.b, p1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %pg)
   ret <vscale x 16 x i1> %out
 }
 
 define <vscale x 16 x i1> @reinterpret_bool_from_s(<vscale x 4 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_from_s:
-; CHECK: ptrue p1.s
-; CHECK-NEXT: ptrue p2.b
-; CHECK-NEXT: and p0.b, p2/z, p0.b, p1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    ptrue p2.b
+; CHECK-NEXT:    and p0.b, p2/z, p0.b, p1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %pg)
   ret <vscale x 16 x i1> %out
 }
 
 define <vscale x 16 x i1> @reinterpret_bool_from_d(<vscale x 2 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_from_d:
-; CHECK: ptrue p1.d
-; CHECK-NEXT: ptrue p2.b
-; CHECK-NEXT: and p0.b, p2/z, p0.b, p1.b
-; CHECK-NEXT: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    ptrue p2.b
+; CHECK-NEXT:    and p0.b, p2/z, p0.b, p1.b
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %pg)
   ret <vscale x 16 x i1> %out
 }
@@ -47,32 +52,61 @@
 
 define <vscale x 16 x i1> @reinterpret_bool_to_b(<vscale x 16 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_to_b:
-; CHECK: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
   %out = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1> %pg)
   ret <vscale x 16 x i1> %out
 }
 
 define <vscale x 8 x i1> @reinterpret_bool_to_h(<vscale x 16 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_to_h:
-; CHECK: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
   %out = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %pg)
   ret <vscale x 8 x i1> %out
 }
 
 define <vscale x 4 x i1> @reinterpret_bool_to_s(<vscale x 16 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_to_s:
-; CHECK: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
   %out = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %pg)
   ret <vscale x 4 x i1> %out
 }
 
 define <vscale x 2 x i1> @reinterpret_bool_to_d(<vscale x 16 x i1> %pg) {
 ; CHECK-LABEL: reinterpret_bool_to_d:
-; CHECK: ret
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
   %out = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %pg)
   ret <vscale x 2 x i1> %out
 }
 
+; Reinterpreting a ptrue should not introduce an `and` instruction.
+define <vscale x 16 x i1> @reinterpret_ptrue() {
+; CHECK-LABEL: reinterpret_ptrue:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    ret
+  %in = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %out = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %in)
+  ret <vscale x 16 x i1> %out
+}
+
+; Reinterpreting a comparison should not introduce an `and` instruction.
+define <vscale x 16 x i1> @reinterpret_cmpgt(<vscale x 8 x i1> %p, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: reinterpret_cmpgt:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmpgt p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.cmpgt.nxv8i16(<vscale x 8 x i1> %p, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  ret <vscale x 16 x i1> %2
+}
+
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 immarg)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.cmpgt.nxv8i16(<vscale x 8 x i1>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv16i1(<vscale x 16 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)