diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -926,6 +926,7 @@
   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerZERO_EXTEND(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1202,6 +1202,7 @@
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
       setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
+      setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
 
       setOperationAction(ISD::UMUL_LOHI, VT, Expand);
       setOperationAction(ISD::SMUL_LOHI, VT, Expand);
@@ -5527,6 +5528,28 @@
   return SDValue();
 }
 
+SDValue AArch64TargetLowering::LowerZERO_EXTEND(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  assert(Op->getOpcode() == ISD::ZERO_EXTEND && "Expected ZERO_EXTEND");
+
+  if (Op.getValueType().isFixedLengthVector())
+    return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
+
+  // zext of a scalable i1 predicate is equivalent to vselect(Pred, 1, 0).
+  // Lower to VSELECT to allow zext to transform into a predicated
+  // instruction like add, sub or mul.
+  SDValue Value = Op->getOperand(0);
+  if (!Value->getValueType(0).isScalableVector() ||
+      Value->getValueType(0).getScalarType() != MVT::i1)
+    return SDValue();
+
+  SDLoc DL(Op);
+  EVT VT = Op->getValueType(0);
+  SDValue Ones = DAG.getConstant(1, DL, VT);
+  SDValue Zeros = DAG.getConstant(0, DL, VT);
+  return DAG.getNode(ISD::VSELECT, DL, VT, Value, Ones, Zeros);
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -5739,8 +5762,9 @@
     return LowerVSCALE(Op, DAG);
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
-  case ISD::ZERO_EXTEND:
     return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
+  case ISD::ZERO_EXTEND:
+    return LowerZERO_EXTEND(Op, DAG);
   case ISD::SIGN_EXTEND_INREG: {
     // Only custom lower when ExtraVT has a legal byte based element type.
     EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -512,13 +512,13 @@
   defm FNEG_ZPmZ : sve_int_un_pred_arit_1_fp<0b101, "fneg", AArch64fneg_mt>;
 
   // zext(cmpeq(x, splat(0))) -> cnot(x)
-  def : Pat<(nxv16i8 (zext (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)))),
+  def : Pat<(nxv16i8 (vselect (nxv16i1 (AArch64setcc_z (nxv16i1 (SVEAllActive):$Pg), nxv16i8:$Op2, (SVEDup0), SETEQ)), (nxv16i8 (splat_vector (i32 1))), (nxv16i8 (splat_vector (i32 0))))),
             (CNOT_ZPmZ_B $Op2, $Pg, $Op2)>;
-  def : Pat<(nxv8i16 (zext (nxv8i1 (AArch64setcc_z (nxv8i1 (SVEAllActive):$Pg), nxv8i16:$Op2, (SVEDup0), SETEQ)))),
+  def : Pat<(nxv8i16 (vselect (nxv8i1 (AArch64setcc_z (nxv8i1 (SVEAllActive):$Pg), nxv8i16:$Op2, (SVEDup0), SETEQ)), (nxv8i16 (splat_vector (i32 1))), (nxv8i16 (splat_vector (i32 0))))),
             (CNOT_ZPmZ_H $Op2, $Pg, $Op2)>;
-  def : Pat<(nxv4i32 (zext (nxv4i1 (AArch64setcc_z (nxv4i1 (SVEAllActive):$Pg), nxv4i32:$Op2, (SVEDup0), SETEQ)))),
+  def : Pat<(nxv4i32 (vselect (nxv4i1 (AArch64setcc_z (nxv4i1 (SVEAllActive):$Pg), nxv4i32:$Op2, (SVEDup0), SETEQ)), (nxv4i32 (splat_vector (i32 1))), (nxv4i32 (splat_vector (i32 0))))),
             (CNOT_ZPmZ_S $Op2, $Pg, $Op2)>;
-  def : Pat<(nxv2i64 (zext (nxv2i1 (AArch64setcc_z (nxv2i1 (SVEAllActive):$Pg), nxv2i64:$Op2, (SVEDup0), SETEQ)))),
+  def : Pat<(nxv2i64 (vselect (nxv2i1 (AArch64setcc_z (nxv2i1 (SVEAllActive):$Pg), nxv2i64:$Op2, (SVEDup0), SETEQ)), (nxv2i64 (splat_vector (i64 1))), (nxv2i64 (splat_vector (i64 0))))),
             (CNOT_ZPmZ_D $Op2, $Pg, $Op2)>;
 
   defm SMAX_ZPmZ : sve_int_bin_pred_arit_1<0b000, "smax", "SMAX_ZPZZ", AArch64smax_m1, DestructiveBinaryComm>;
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1853,6 +1853,12 @@
     { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
     { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
     { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
+
+    // Zero extends from nxvMi1 to nxvMiN: a single 'mov zd, pg/z, #1'.
+    { ISD::ZERO_EXTEND, MVT::nxv2i64, MVT::nxv2i1, 1 },
+    { ISD::ZERO_EXTEND, MVT::nxv4i32, MVT::nxv4i1, 1 },
+    { ISD::ZERO_EXTEND, MVT::nxv8i16, MVT::nxv8i1, 1 },
+    { ISD::ZERO_EXTEND, MVT::nxv16i8, MVT::nxv16i1, 1 },
   };
 
   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4702,8 +4702,6 @@
                                   (ins PPRAny:$Pg, cpyimm:$imm)>;
   def : InstAlias<"mov $Zd, $Pg/z, $imm",
                   (!cast<Instruction>(NAME) zprty:$Zd, PPRAny:$Pg, cpyimm:$imm), 1>;
-  def : Pat<(intty (zext (predty PPRAny:$Ps1))),
-            (!cast<Instruction>(NAME) PPRAny:$Ps1, 1, 0)>;
   def : Pat<(intty (sext (predty PPRAny:$Ps1))),
             (!cast<Instruction>(NAME) PPRAny:$Ps1, -1, 0)>;
   def : Pat<(intty (anyext (predty PPRAny:$Ps1))),
diff --git a/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll
@@ -0,0 +1,224 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux"
+
+define <vscale x 2 x i32> @zext.add2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: zext.add2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, #1 // =0x1
+; CHECK-NEXT:    add z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = add <vscale x 2 x i32> %zero.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @zext.add4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: zext.add4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, #1 // =0x1
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = add <vscale x 4 x i32> %zero.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @zext.add8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: zext.add8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    mov z2.s, #1 // =0x1
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z2.s
+; CHECK-NEXT:    add z1.s, p1/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = add <vscale x 8 x i32> %zero.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @zext.add16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: zext.add16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    mov z4.s, #1 // =0x1
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z4.s
+; CHECK-NEXT:    add z1.s, p3/m, z1.s, z4.s
+; CHECK-NEXT:    add z2.s, p2/m, z2.s, z4.s
+; CHECK-NEXT:    add z3.s, p1/m, z3.s, z4.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = add <vscale x 16 x i32> %zero.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
+define <vscale x 2 x i32> @zext.sub2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: zext.sub2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, p0/z, #1 // =0x1
+; CHECK-NEXT:    sub z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = sub <vscale x 2 x i32> %zero.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @zext.sub4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: zext.sub4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    sub z0.s, z1.s, z0.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = sub <vscale x 4 x i32> %zero.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @zext.sub8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: zext.sub8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    mov z2.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    mov z3.s, p1/z, #1 // =0x1
+; CHECK-NEXT:    sub z0.s, z3.s, z0.s
+; CHECK-NEXT:    sub z1.s, z2.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = sub <vscale x 8 x i32> %zero.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @zext.sub16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: zext.sub16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    mov z4.s, p0/z, #1 // =0x1
+; CHECK-NEXT:    mov z5.s, p1/z, #1 // =0x1
+; CHECK-NEXT:    mov z6.s, p3/z, #1 // =0x1
+; CHECK-NEXT:    mov z7.s, p2/z, #1 // =0x1
+; CHECK-NEXT:    sub z0.s, z7.s, z0.s
+; CHECK-NEXT:    sub z1.s, z6.s, z1.s
+; CHECK-NEXT:    sub z2.s, z5.s, z2.s
+; CHECK-NEXT:    sub z3.s, z4.s, z3.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = sub <vscale x 16 x i32> %zero.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
+define <vscale x 2 x i32> @zext.mul2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: zext.mul2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p1.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, p1/z, #1 // =0x1
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = mul <vscale x 2 x i32> %zero.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @zext.mul4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: zext.mul4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, p1/z, #1 // =0x1
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = mul <vscale x 4 x i32> %zero.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @zext.mul8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: zext.mul8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    mov z2.s, p2/z, #1 // =0x1
+; CHECK-NEXT:    mov z3.s, p1/z, #1 // =0x1
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z3.s
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = mul <vscale x 8 x i32> %zero.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @zext.mul16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: zext.mul16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    cmphi p4.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    mov z4.s, p4/z, #1 // =0x1
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    cmphi p1.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    mov z5.s, p1/z, #1 // =0x1
+; CHECK-NEXT:    mov z6.s, p3/z, #1 // =0x1
+; CHECK-NEXT:    mov z7.s, p2/z, #1 // =0x1
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z6.s
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z7.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z5.s
+; CHECK-NEXT:    mul z3.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %zero.extend = zext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = mul <vscale x 16 x i32> %zero.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
+attributes #0 = { "target-features"="+sve" }