diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5409,6 +5409,33 @@
     }
   }
 
+  // logic_op (vselect (A, splat_vector(-1), splat_vector(0)),
+  //           vselect (B, splat_vector(-1), splat_vector(0))) -->
+  // vselect (logic_op (A, B), splat_vector(-1), splat_vector(0))
+  // where A and B are scalable vectors with an i1 element type
+  APInt SplatVal;
+  if (HandOpcode == ISD::VSELECT &&
+      N0->getOperand(0)->getValueType(0).isScalableVector() &&
+      N0->getOperand(0)->getValueType(0).getScalarType() == MVT::i1 &&
+      ISD::isConstantSplatVectorAllZeros(N0->getOperand(2).getNode()) &&
+      ISD::isConstantSplatVector(N0->getOperand(1).getNode(), SplatVal) &&
+      SplatVal.getSExtValue() == -1 &&
+      N1->getOperand(0)->getValueType(0).isScalableVector() &&
+      N1->getOperand(0)->getValueType(0).getScalarType() == MVT::i1 &&
+      ISD::isConstantSplatVectorAllZeros(N1->getOperand(2).getNode()) &&
+      ISD::isConstantSplatVector(N1->getOperand(1).getNode(), SplatVal) &&
+      SplatVal.getSExtValue() == -1) {
+
+    // If both operands have other uses, this transform would create extra
+    // instructions without eliminating anything.
+    if (!N0.hasOneUse() && !N1.hasOneUse())
+      return SDValue();
+
+    SDValue Logic = DAG.getNode(LogicOpcode, DL, XVT, X, Y);
+    return DAG.getNode(ISD::VSELECT, DL, VT, Logic, N0.getOperand(1),
+                       N0.getOperand(2));
+  }
+
   return SDValue();
 }
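
A note on the combine above: the IR below is a minimal sketch of the shape it is intended to catch (my own example, not taken from the patch's tests, and the generated code is not verified here). Once the AArch64 change further down lowers each sext from i1 to vselect(pred, splat(-1), splat(0)), the two vselect nodes feeding the "and" share the same splat operands, so this fold should be able to replace them with a single vselect whose predicate is the AND of the two compare results. Something like "llc -mtriple=aarch64 -mattr=+sve" can be used to build it.

define <vscale x 4 x i32> @and_of_sexts(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i32> %c) {
  ; Two i1 compare results, each sign-extended exactly once and then combined with an and.
  %cmp0 = icmp ugt <vscale x 4 x i32> %a, %b
  %cmp1 = icmp ult <vscale x 4 x i32> %a, %c
  %ext0 = sext <vscale x 4 x i1> %cmp0 to <vscale x 4 x i32>
  %ext1 = sext <vscale x 4 x i1> %cmp1 to <vscale x 4 x i32>
  %r = and <vscale x 4 x i32> %ext0, %ext1
  ret <vscale x 4 x i32> %r
}
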
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -927,6 +927,7 @@
   SDValue LowerStore128(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerABS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerZERO_EXTEND(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerSIGN_EXTEND(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1203,6 +1203,7 @@
     setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
     setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
     setOperationAction(ISD::ZERO_EXTEND, VT, Custom);
+    setOperationAction(ISD::SIGN_EXTEND, VT, Custom);
     setOperationAction(ISD::UMUL_LOHI, VT, Expand);
     setOperationAction(ISD::SMUL_LOHI, VT, Expand);
@@ -5542,6 +5543,26 @@
   return DAG.getNode(ISD::VSELECT, DL, VT, Value, Ones, Zeros);
 }
 
+SDValue AArch64TargetLowering::LowerSIGN_EXTEND(SDValue Op, SelectionDAG &DAG) const {
+  assert(Op->getOpcode() == ISD::SIGN_EXTEND && "Expected SIGN_EXTEND");
+
+  if (Op.getValueType().isFixedLengthVector())
+    return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
+
+  // Try to lower to VSELECT to allow the sext to transform into
+  // a predicated instruction like add, sub or mul.
+  SDValue Value = Op->getOperand(0);
+  if (!Value->getValueType(0).isScalableVector() ||
+      Value->getValueType(0).getScalarType() != MVT::i1)
+    return SDValue();
+
+  SDLoc DL = SDLoc(Op);
+  EVT VT = Op->getValueType(0);
+  SDValue MinusOnes = DAG.getConstant(APInt(VT.getScalarType().getScalarSizeInBits(), -1, true), DL, VT);
+  SDValue Zeros = DAG.getConstant(APInt(VT.getScalarType().getScalarSizeInBits(), 0, true), DL, VT);
+  return DAG.getNode(ISD::VSELECT, DL, VT, Value, MinusOnes, Zeros);
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
@@ -5753,8 +5774,9 @@
   case ISD::VSCALE:
     return LowerVSCALE(Op, DAG);
   case ISD::ANY_EXTEND:
-  case ISD::SIGN_EXTEND:
     return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
+  case ISD::SIGN_EXTEND:
+    return LowerSIGN_EXTEND(Op, DAG);
   case ISD::ZERO_EXTEND:
     return LowerZERO_EXTEND(Op, DAG);
   case ISD::SIGN_EXTEND_INREG: {
@@ -18332,13 +18354,25 @@
   return SDValue();
 }
 
+static bool isVSELECTLoweredFromSIGN_EXTEND(SDValue N) {
+  APInt SplatVal;
+  if (N->getOpcode() != ISD::VSELECT ||
+      !N->getOperand(0)->getValueType(0).isScalableVector() ||
+      N->getOperand(0)->getValueType(0).getScalarType() != MVT::i1 ||
+      !ISD::isConstantSplatVectorAllZeros(N->getOperand(2).getNode()) ||
+      !ISD::isConstantSplatVector(N->getOperand(1).getNode(), SplatVal) ||
+      SplatVal.getSExtValue() != -1)
+    return false;
+  return true;
+}
+
 static SDValue performSunpkloCombine(SDNode *N, SelectionDAG &DAG) {
   // sunpklo(sext(pred)) -> sext(extract_low_half(pred))
   // This transform works in partnership with performSetCCPunpkCombine to
   // remove unnecessary transfer of predicates into standard registers and back
-  if (N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND &&
+  if ((N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND &&
        N->getOperand(0)->getOperand(0)->getValueType(0).getScalarType() ==
-          MVT::i1) {
+          MVT::i1) || isVSELECTLoweredFromSIGN_EXTEND(N->getOperand(0))) {
     SDValue CC = N->getOperand(0)->getOperand(0);
     auto VT = CC->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
     SDValue Unpk = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, CC,
@@ -19528,7 +19562,7 @@
   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
 
   if (Cond != ISD::SETNE || !isZerosVector(RHS.getNode()) ||
-      LHS->getOpcode() != ISD::SIGN_EXTEND)
+      (LHS->getOpcode() != ISD::SIGN_EXTEND && !isVSELECTLoweredFromSIGN_EXTEND(LHS)))
     return SDValue();
 
   SDValue Extract = LHS->getOperand(0);
@@ -19571,7 +19605,7 @@
     return V;
 
   if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
-      LHS->getOpcode() == ISD::SIGN_EXTEND &&
+      (LHS->getOpcode() == ISD::SIGN_EXTEND || isVSELECTLoweredFromSIGN_EXTEND(LHS)) &&
       LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
     // setcc_merge_zero(
     //    pred, extend(setcc_merge_zero(pred, ...)), != splat(0))
@@ -22562,12 +22596,13 @@
   EVT MaskVT = Op.getOperand(0).getValueType();
   EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
   auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
-  Mask = DAG.getNode(ISD::TRUNCATE, DL,
-                     MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
-
-  auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
-                                 Mask, Op1, Op2);
-
+  if (isVSELECTLoweredFromSIGN_EXTEND(Mask)) {
+    Mask = Mask->getOperand(0);
+  } else {
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
+  }
+  auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT, Mask, Op1, Op2);
   return convertFromScalableVector(DAG, VT, ScalableRes);
 }
 
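
Since the dedicated sext pattern is removed from SVEInstrFormats.td later in this patch, the case where the extended predicate has no arithmetic user now has to be covered by the VSELECT form created in LowerSIGN_EXTEND above. The function below is a minimal sketch of that case (my own example, not one of the patch's tests); the expectation, which I have not verified, is that it still selects to a single zeroing mov of #-1 through the existing vselect pattern in the same multiclass.

define <vscale x 4 x i32> @sext_pred_only(<vscale x 4 x i1> %p) {
  ; The extended value is returned directly, so no predicated add/sub/mul is formed.
  %ext = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
  ret <vscale x 4 x i32> %ext
}
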
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1895,6 +1895,12 @@
     { ISD::ZERO_EXTEND, MVT::nxv4i32, MVT::nxv4i1, 1 },
     { ISD::ZERO_EXTEND, MVT::nxv8i16, MVT::nxv8i1, 1 },
     { ISD::ZERO_EXTEND, MVT::nxv16i8, MVT::nxv16i1, 1 },
+
+    // Sign extends from nxvmi1 to nxvmiN.
+    { ISD::SIGN_EXTEND, MVT::nxv2i64, MVT::nxv2i1, 1 },
+    { ISD::SIGN_EXTEND, MVT::nxv4i32, MVT::nxv4i1, 1 },
+    { ISD::SIGN_EXTEND, MVT::nxv8i16, MVT::nxv8i1, 1 },
+    { ISD::SIGN_EXTEND, MVT::nxv16i8, MVT::nxv16i1, 1 },
   };
 
   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
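
To see the effect of the new cost-table entries, the cost-model printer can be run over a small function such as the one below (my own sketch with a made-up function name; the reported numbers are not verified here). With the entries above, sign extending an SVE predicate should be costed at 1, in line with the existing zero-extend entries, for example via: opt -passes='print<cost-model>' -disable-output -mtriple=aarch64 -mattr=+sve file.ll

define <vscale x 8 x i16> @sext_cost_query(<vscale x 8 x i1> %p) {
  ; Intended to hit the new { ISD::SIGN_EXTEND, MVT::nxv8i16, MVT::nxv8i1, 1 } entry.
  %ext = sext <vscale x 8 x i1> %p to <vscale x 8 x i16>
  ret <vscale x 8 x i16> %ext
}
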
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -4702,8 +4702,6 @@
                   (ins PPRAny:$Pg, cpyimm:$imm)>;
   def : InstAlias<"mov $Zd, $Pg/z, $imm",
                   (!cast<Instruction>(NAME) zprty:$Zd, PPRAny:$Pg, cpyimm:$imm), 1>;
-  def : Pat<(intty (sext (predty PPRAny:$Ps1))),
-            (!cast<Instruction>(NAME) PPRAny:$Ps1, -1, 0)>;
   def : Pat<(intty (anyext (predty PPRAny:$Ps1))),
             (!cast<Instruction>(NAME) PPRAny:$Ps1, 1, 0)>;
   def : Pat<(vselect predty:$Pg,
diff --git a/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll b/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll
--- a/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll
+++ b/llvm/test/CodeGen/AArch64/predicated-add-sub-mul.ll
@@ -221,4 +221,222 @@
   ret <vscale x 16 x i32> %result
 }
 
+define <vscale x 2 x i32> @sext.add2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: sext.add2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    add z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = add <vscale x 2 x i32> %sign.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @sext.add4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: sext.add4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = add <vscale x 4 x i32> %sign.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @sext.add8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: sext.add8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    mov z2.s, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z2.s
+; CHECK-NEXT:    add z1.s, p1/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = add <vscale x 8 x i32> %sign.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @sext.add16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: sext.add16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    mov z4.s, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    add z0.s, p0/m, z0.s, z4.s
+; CHECK-NEXT:    add z1.s, p3/m, z1.s, z4.s
+; CHECK-NEXT:    add z2.s, p2/m, z2.s, z4.s
+; CHECK-NEXT:    add z3.s, p1/m, z3.s, z4.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = add <vscale x 16 x i32> %sign.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
+define <vscale x 2 x i32> @sext.sub2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: sext.sub2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    sub z0.d, z1.d, z0.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = sub <vscale x 2 x i32> %sign.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @sext.sub4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: sext.sub4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    sub z0.s, z1.s, z0.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = sub <vscale x 4 x i32> %sign.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @sext.sub8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: sext.sub8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    mov z2.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z3.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    sub z0.s, z3.s, z0.s
+; CHECK-NEXT:    sub z1.s, z2.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = sub <vscale x 8 x i32> %sign.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @sext.sub16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: sext.sub16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    cmphi p0.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    mov z4.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z5.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z6.s, p3/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z7.s, p2/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    sub z0.s, z7.s, z0.s
+; CHECK-NEXT:    sub z1.s, z6.s, z1.s
+; CHECK-NEXT:    sub z2.s, z5.s, z2.s
+; CHECK-NEXT:    sub z3.s, z4.s, z3.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = sub <vscale x 16 x i32> %sign.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
+define <vscale x 2 x i32> @sext.mul2(<vscale x 2 x i32> %a0, <vscale x 2 x i32> %a1) #0 {
+; CHECK-LABEL: sext.mul2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z1.d, z1.d, #0xffffffff
+; CHECK-NEXT:    and z2.d, z2.d, #0xffffffff
+; CHECK-NEXT:    cmphi p1.d, p0/z, z2.d, z1.d
+; CHECK-NEXT:    mov z1.d, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 2 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 2 x i1> %v to <vscale x 2 x i32>
+  %result = mul <vscale x 2 x i32> %sign.extend, %a0
+  ret <vscale x 2 x i32> %result
+}
+
+define <vscale x 4 x i32> @sext.mul4(<vscale x 4 x i32> %a0, <vscale x 4 x i32> %a1) #0 {
+; CHECK-LABEL: sext.mul4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 4 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 4 x i1> %v to <vscale x 4 x i32>
+  %result = mul <vscale x 4 x i32> %sign.extend, %a0
+  ret <vscale x 4 x i32> %result
+}
+
+define <vscale x 8 x i32> @sext.mul8(<vscale x 8 x i32> %a0, <vscale x 8 x i32> %a1) #0 {
+; CHECK-LABEL: sext.mul8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmphi p1.s, p0/z, z0.s, z2.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z1.s, z3.s
+; CHECK-NEXT:    mov z2.s, p2/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z3.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z3.s
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z2.s
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 8 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 8 x i1> %v to <vscale x 8 x i32>
+  %result = mul <vscale x 8 x i32> %sign.extend, %a0
+  ret <vscale x 8 x i32> %result
+}
+
+define <vscale x 16 x i32> @sext.mul16(<vscale x 16 x i32> %a0, <vscale x 16 x i32> %a1) #0 {
+; CHECK-LABEL: sext.mul16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    str p4, [sp, #7, mul vl] // 2-byte Folded Spill
+; CHECK-NEXT:    cmphi p4.s, p0/z, z3.s, z7.s
+; CHECK-NEXT:    cmphi p2.s, p0/z, z0.s, z4.s
+; CHECK-NEXT:    mov z4.s, p4/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    ldr p4, [sp, #7, mul vl] // 2-byte Folded Reload
+; CHECK-NEXT:    cmphi p1.s, p0/z, z2.s, z6.s
+; CHECK-NEXT:    cmphi p3.s, p0/z, z1.s, z5.s
+; CHECK-NEXT:    mov z5.s, p1/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z6.s, p3/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mov z7.s, p2/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    mul z1.s, p0/m, z1.s, z6.s
+; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z7.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z5.s
+; CHECK-NEXT:    mul z3.s, p0/m, z3.s, z4.s
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %v = icmp ugt <vscale x 16 x i32> %a0, %a1
+  %sign.extend = sext <vscale x 16 x i1> %v to <vscale x 16 x i32>
+  %result = mul <vscale x 16 x i32> %sign.extend, %a0
+  ret <vscale x 16 x i32> %result
+}
+
 attributes #0 = { "target-features"="+sve" }