diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20493,7 +20493,8 @@
     EVT N0SrcSVT = N0Src.getValueType().getScalarType();
     EVT N1SrcSVT = N1Src.getValueType().getScalarType();
     if ((N0.isUndef() || N0SrcSVT == N1SrcSVT) &&
-        N0Src.getValueType().isVector() && N1Src.getValueType().isVector()) {
+        N0Src.getValueType().isFixedLengthVector() &&
+        N1Src.getValueType().isFixedLengthVector()) {
       EVT NewVT;
       SDLoc DL(N);
       SDValue NewIdx;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -74,10 +74,20 @@
   // Arithmetic instructions
   ADD_PRED,
+  AND_PRED,
+  EOR_PRED,
   FADD_PRED,
+  FDIV_PRED,
+  FMA_PRED,
+  FMAXNM_PRED,
+  FMINNM_PRED,
+  FMUL_PRED,
+  FSUB_PRED,
+  MUL_PRED,
+  ORR_PRED,
   SDIV_PRED,
+  SUB_PRED,
   UDIV_PRED,
-  FMA_PRED,
   SMIN_MERGE_OP1,
   UMIN_MERGE_OP1,
   SMAX_MERGE_OP1,
@@ -891,6 +901,7 @@
                                 EVT VT, SelectionDAG &DAG, const SDLoc &DL) const;
   SDValue LowerFixedLengthVectorLoadToSVE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerFixedLengthVectorSetccToSVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFixedLengthVectorStoreToSVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFixedLengthVectorTruncateToSVE(SDValue Op,
                                               SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1069,10 +1069,34 @@
   // Lower fixed length vector operations to scalable equivalents.
   setOperationAction(ISD::ADD, VT, Custom);
+  setOperationAction(ISD::AND, VT, Custom);
   setOperationAction(ISD::FADD, VT, Custom);
+  setOperationAction(ISD::FDIV, VT, Custom);
+  setOperationAction(ISD::FMAXNUM, VT, Custom);
+  setOperationAction(ISD::FMINNUM, VT, Custom);
+  setOperationAction(ISD::FMUL, VT, Custom);
+  setOperationAction(ISD::FSUB, VT, Custom);
   setOperationAction(ISD::LOAD, VT, Custom);
+  setOperationAction(ISD::MUL, VT, Custom);
+  setOperationAction(ISD::OR, VT, Custom);
+  setOperationAction(ISD::SETCC, VT, Custom);
+  setOperationAction(ISD::SHL, VT, Custom);
+  setOperationAction(ISD::SMAX, VT, Custom);
+  setOperationAction(ISD::SMIN, VT, Custom);
+  setOperationAction(ISD::SRA, VT, Custom);
+  setOperationAction(ISD::SRL, VT, Custom);
   setOperationAction(ISD::STORE, VT, Custom);
+  setOperationAction(ISD::SUB, VT, Custom);
   setOperationAction(ISD::TRUNCATE, VT, Custom);
+  setOperationAction(ISD::UMAX, VT, Custom);
+  setOperationAction(ISD::UMIN, VT, Custom);
+  setOperationAction(ISD::XOR, VT, Custom);
+
+  if (VT.getVectorElementType() == MVT::i32 ||
+      VT.getVectorElementType() == MVT::i64) {
+    setOperationAction(ISD::SDIV, VT, Custom);
+    setOperationAction(ISD::UDIV, VT, Custom);
+  }
 }
 
 void AArch64TargetLowering::addDRTypeForNEON(MVT VT) {
@@ -1385,7 +1409,12 @@
     MAKE_CASE(AArch64ISD::THREAD_POINTER)
     MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
     MAKE_CASE(AArch64ISD::ADD_PRED)
+    MAKE_CASE(AArch64ISD::AND_PRED)
+    MAKE_CASE(AArch64ISD::EOR_PRED)
+    MAKE_CASE(AArch64ISD::MUL_PRED)
+    MAKE_CASE(AArch64ISD::ORR_PRED)
     MAKE_CASE(AArch64ISD::SDIV_PRED)
+    MAKE_CASE(AArch64ISD::SUB_PRED)
     MAKE_CASE(AArch64ISD::UDIV_PRED)
     MAKE_CASE(AArch64ISD::SMIN_MERGE_OP1)
     MAKE_CASE(AArch64ISD::UMIN_MERGE_OP1)
@@ -1483,11 +1512,16 @@
     MAKE_CASE(AArch64ISD::FADD_PRED)
    MAKE_CASE(AArch64ISD::FADDA_PRED)
     MAKE_CASE(AArch64ISD::FADDV_PRED)
+    MAKE_CASE(AArch64ISD::FDIV_PRED)
     MAKE_CASE(AArch64ISD::FMA_PRED)
     MAKE_CASE(AArch64ISD::FMAXV_PRED)
+    MAKE_CASE(AArch64ISD::FMAXNM_PRED)
     MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
     MAKE_CASE(AArch64ISD::FMINV_PRED)
+    MAKE_CASE(AArch64ISD::FMINNM_PRED)
     MAKE_CASE(AArch64ISD::FMINNMV_PRED)
+    MAKE_CASE(AArch64ISD::FMUL_PRED)
+    MAKE_CASE(AArch64ISD::FSUB_PRED)
     MAKE_CASE(AArch64ISD::NOT)
     MAKE_CASE(AArch64ISD::BIT)
     MAKE_CASE(AArch64ISD::CBZ)
@@ -3472,12 +3506,18 @@
       return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
     return LowerF128Call(Op, DAG, RTLIB::ADD_F128);
   case ISD::FSUB:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
     return LowerF128Call(Op, DAG, RTLIB::SUB_F128);
   case ISD::FMUL:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
     return LowerF128Call(Op, DAG, RTLIB::MUL_F128);
   case ISD::FMA:
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
   case ISD::FDIV:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
     return LowerF128Call(Op, DAG, RTLIB::DIV_F128);
   case ISD::FP_ROUND:
   case ISD::STRICT_FP_ROUND:
@@ -3534,6 +3574,8 @@
   case ISD::OR:
     return LowerVectorOR(Op, DAG);
   case ISD::XOR:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::EOR_PRED);
     return LowerXOR(Op, DAG);
   case ISD::PREFETCH:
     return LowerPREFETCH(Op, DAG);
@@ -3552,6 +3594,8 @@
   case ISD::FLT_ROUNDS_:
     return LowerFLT_ROUNDS_(Op, DAG);
   case ISD::MUL:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
     return LowerMUL(Op, DAG);
   case ISD::INTRINSIC_WO_CHAIN:
     return LowerINTRINSIC_WO_CHAIN(Op, DAG);
@@ -3583,6 +3627,22 @@
     if (useSVEForFixedLengthVectorVT(Op.getValueType()))
       return LowerToPredicatedOp(Op, DAG, AArch64ISD::ADD_PRED);
     llvm_unreachable("Unexpected request to lower ISD::ADD");
+  case ISD::SUB:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::SUB_PRED);
+    llvm_unreachable("Unexpected request to lower ISD::SUB");
+  case ISD::AND:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::AND_PRED);
+    llvm_unreachable("Unexpected request to lower ISD::AND");
+  case ISD::FMAXNUM:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
+    llvm_unreachable("Unexpected request to lower ISD::FMAXNUM");
+  case ISD::FMINNUM:
+    if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+      return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
+    llvm_unreachable("Unexpected request to lower ISD::FMINNUM");
   }
 }
@@ -8246,6 +8306,9 @@
 
 SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
                                              SelectionDAG &DAG) const {
+  if (useSVEForFixedLengthVectorVT(Op.getValueType()))
+    return LowerToPredicatedOp(Op, DAG, AArch64ISD::ORR_PRED);
+
   // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
   if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
     return Res;
@@ -8936,7 +8999,7 @@
     llvm_unreachable("unexpected shift opcode");
 
   case ISD::SHL:
-    if (VT.isScalableVector())
+    if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT))
       return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_MERGE_OP1);
 
     if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
@@ -8948,7 +9011,7 @@
                        Op.getOperand(0), Op.getOperand(1));
   case ISD::SRA:
   case ISD::SRL:
-    if (VT.isScalableVector()) {
+    if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT)) {
       unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_MERGE_OP1
                                                 : AArch64ISD::SRL_MERGE_OP1;
       return LowerToPredicatedOp(Op, DAG, Opc);
@@ -9082,6 +9145,9 @@
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
   }
 
+  if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType()))
+    return LowerFixedLengthVectorSetccToSVE(Op, DAG);
+
   ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
@@ -11064,6 +11130,10 @@
   if (VT.isScalableVector())
     return performSVEAndCombine(N, DCI);
 
+  // TODO: useSVEForFixedLengthVectorVT?
+  if (VT.getSizeInBits() > 128)
+    return SDValue();
+
   BuildVectorSDNode *BVN =
       dyn_cast<BuildVectorSDNode>(N->getOperand(1).getNode());
   if (!BVN)
@@ -15154,3 +15224,26 @@
 
   return DAG.getNode(NewOp, DL, VT, Operands);
 }
+
+SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
+    SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getOperand(0).getValueType();
+  EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
+
+  // Expand floating point vector comparisons.
+  if (InVT.isFloatingPoint())
+    return SDValue();
+
+  auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
+  auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
+  auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
+
+  EVT CmpVT = Pg.getValueType();
+  SmallVector<SDValue, 4> CmpOps = {Pg, Op1, Op2, Op.getOperand(2)};
+  auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT, CmpOps);
+
+  auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, ContainerVT, InVT);
+  auto Extract = convertFromScalableVector(DAG, InVT, Promote);
+  return DAG.getNode(ISD::TRUNCATE, DL, Op.getValueType(), Extract);
+}
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -173,11 +173,21 @@
 ]>;
 
 // Predicated operations with the result of inactive lanes being unspecified.
-def AArch64add_p  : SDNode<"AArch64ISD::ADD_PRED",  SDT_AArch64Arith>;
-def AArch64fadd_p : SDNode<"AArch64ISD::FADD_PRED", SDT_AArch64Arith>;
-def AArch64fma_p  : SDNode<"AArch64ISD::FMA_PRED",  SDT_AArch64FMA>;
-def AArch64sdiv_p : SDNode<"AArch64ISD::SDIV_PRED", SDT_AArch64Arith>;
-def AArch64udiv_p : SDNode<"AArch64ISD::UDIV_PRED", SDT_AArch64Arith>;
+def AArch64add_p    : SDNode<"AArch64ISD::ADD_PRED",    SDT_AArch64Arith>;
+def AArch64and_p    : SDNode<"AArch64ISD::AND_PRED",    SDT_AArch64Arith>;
+def AArch64eor_p    : SDNode<"AArch64ISD::EOR_PRED",    SDT_AArch64Arith>;
+def AArch64fadd_p   : SDNode<"AArch64ISD::FADD_PRED",   SDT_AArch64Arith>;
+def AArch64fdiv_p   : SDNode<"AArch64ISD::FDIV_PRED",   SDT_AArch64Arith>;
+def AArch64fma_p    : SDNode<"AArch64ISD::FMA_PRED",    SDT_AArch64FMA>;
+def AArch64fmaxnm_p : SDNode<"AArch64ISD::FMAXNM_PRED", SDT_AArch64Arith>;
+def AArch64fminnm_p : SDNode<"AArch64ISD::FMINNM_PRED", SDT_AArch64Arith>;
+def AArch64fmul_p   : SDNode<"AArch64ISD::FMUL_PRED",   SDT_AArch64Arith>;
+def AArch64fsub_p   : SDNode<"AArch64ISD::FSUB_PRED",   SDT_AArch64Arith>;
+def AArch64mul_p    : SDNode<"AArch64ISD::MUL_PRED",    SDT_AArch64Arith>;
+def AArch64orr_p    : SDNode<"AArch64ISD::ORR_PRED",    SDT_AArch64Arith>;
+def AArch64sdiv_p   : SDNode<"AArch64ISD::SDIV_PRED",   SDT_AArch64Arith>;
+def AArch64sub_p    : SDNode<"AArch64ISD::SUB_PRED",    SDT_AArch64Arith>;
+def AArch64udiv_p   : SDNode<"AArch64ISD::UDIV_PRED",   SDT_AArch64Arith>;
 
 // Merging op1 into the inactive lanes.
 def AArch64smin_m1 : SDNode<"AArch64ISD::SMIN_MERGE_OP1", SDT_AArch64Arith>;
@@ -231,6 +241,7 @@
   defm SUBR_ZPmZ : sve_int_bin_pred_arit_0<0b011, "subr", "SUBR_ZPZZ", int_aarch64_sve_subr, DestructiveBinaryCommWithRev, "SUB_ZPmZ", /*isReverseInstr*/ 1>;
 
   defm ADD_ZPZZ : sve_int_bin_pred_bhsd<AArch64add_p>;
+  defm SUB_ZPZZ : sve_int_bin_pred_bhsd<AArch64sub_p>;
 
   let Predicates = [HasSVE, UseExperimentalZeroingPseudos] in {
     defm ADD_ZPZZ  : sve_int_bin_pred_zeroing_bhsd<int_aarch64_sve_add>;
@@ -238,10 +249,14 @@
     defm SUBR_ZPZZ : sve_int_bin_pred_zeroing_bhsd<int_aarch64_sve_subr>;
   }
 
-  defm ORR_ZPmZ : sve_int_bin_pred_log<0b000, "orr", int_aarch64_sve_orr>;
-  defm EOR_ZPmZ : sve_int_bin_pred_log<0b001, "eor", int_aarch64_sve_eor>;
-  defm AND_ZPmZ : sve_int_bin_pred_log<0b010, "and", int_aarch64_sve_and>;
-  defm BIC_ZPmZ : sve_int_bin_pred_log<0b011, "bic", int_aarch64_sve_bic>;
+  defm ORR_ZPmZ : sve_int_bin_pred_log<0b000, "orr", "ORR_ZPZZ", int_aarch64_sve_orr, DestructiveBinaryComm>;
+  defm EOR_ZPmZ : sve_int_bin_pred_log<0b001, "eor", "EOR_ZPZZ", int_aarch64_sve_eor, DestructiveBinaryComm>;
+  defm AND_ZPmZ : sve_int_bin_pred_log<0b010, "and", "AND_ZPZZ", int_aarch64_sve_and, DestructiveBinaryComm>;
+  defm BIC_ZPmZ : sve_int_bin_pred_log<0b011, "bic", "BIC_ZPZZ", int_aarch64_sve_bic, DestructiveOther>;
+
+  defm AND_ZPZZ : sve_int_bin_pred_bhsd<AArch64and_p>;
+  defm EOR_ZPZZ : sve_int_bin_pred_bhsd<AArch64eor_p>;
+  defm ORR_ZPZZ : sve_int_bin_pred_bhsd<AArch64orr_p>;
 
   defm ADD_ZI   : sve_int_arith_imm0<0b000, "add", add, null_frag>;
   defm SUB_ZI   : sve_int_arith_imm0<0b001, "sub", sub, null_frag>;
@@ -277,9 +292,11 @@
   defm UMIN_ZI   : sve_int_arith_imm1_unsigned<0b11, "umin", AArch64umin_m1>;
 
   defm MUL_ZI     : sve_int_arith_imm2<"mul", mul>;
-  defm MUL_ZPmZ   : sve_int_bin_pred_arit_2<0b000, "mul",   int_aarch64_sve_mul>;
-  defm SMULH_ZPmZ : sve_int_bin_pred_arit_2<0b010, "smulh", int_aarch64_sve_smulh>;
-  defm UMULH_ZPmZ : sve_int_bin_pred_arit_2<0b011, "umulh", int_aarch64_sve_umulh>;
+  defm MUL_ZPmZ   : sve_int_bin_pred_arit_2<0b000, "mul",   "MUL_ZPZZ",   int_aarch64_sve_mul,   DestructiveBinaryComm>;
+  defm SMULH_ZPmZ : sve_int_bin_pred_arit_2<0b010, "smulh", "SMULH_ZPZZ", int_aarch64_sve_smulh, DestructiveBinaryComm>;
+  defm UMULH_ZPmZ : sve_int_bin_pred_arit_2<0b011, "umulh", "UMULH_ZPZZ", int_aarch64_sve_umulh, DestructiveBinaryComm>;
+
+  defm MUL_ZPZZ : sve_int_bin_pred_bhsd<AArch64mul_p>;
 
   // Add unpredicated alternative for the mul instruction.
   def : Pat<(mul nxv16i8:$Op1, nxv16i8:$Op2),
@@ -361,6 +378,11 @@
   defm FDIV_ZPmZ   : sve_fp_2op_p_zds<0b1101, "fdiv",   "FDIV_ZPZZ",   int_aarch64_sve_fdiv, DestructiveBinaryCommWithRev, "FDIVR_ZPmZ">;
 
   defm FADD_ZPZZ   : sve_fp_bin_pred_hfd<AArch64fadd_p>;
+  defm FSUB_ZPZZ   : sve_fp_bin_pred_hfd<AArch64fsub_p>;
+  defm FMUL_ZPZZ   : sve_fp_bin_pred_hfd<AArch64fmul_p>;
+  defm FMAXNM_ZPZZ : sve_fp_bin_pred_hfd<AArch64fmaxnm_p>;
+  defm FMINNM_ZPZZ : sve_fp_bin_pred_hfd<AArch64fminnm_p>;
+  defm FDIV_ZPZZ   : sve_fp_bin_pred_hfd<AArch64fdiv_p>;
 
   let Predicates = [HasSVE, UseExperimentalZeroingPseudos] in {
     defm FADD_ZPZZ : sve_fp_2op_p_zds_zeroing_hsd<int_aarch64_sve_fadd>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -2339,11 +2339,19 @@
   let ElementSize = zprty.ElementSize;
 }
 
-multiclass sve_int_bin_pred_log<bits<3> opc, string asm, SDPatternOperator op> {
-  def _B : sve_int_bin_pred_arit_log<0b00, 0b11, opc, asm, ZPR8>;
-  def _H : sve_int_bin_pred_arit_log<0b01, 0b11, opc, asm, ZPR16>;
-  def _S : sve_int_bin_pred_arit_log<0b10, 0b11, opc, asm, ZPR32>;
-  def _D : sve_int_bin_pred_arit_log<0b11, 0b11, opc, asm, ZPR64>;
+multiclass sve_int_bin_pred_log<bits<3> opc, string asm, string Ps,
+                                SDPatternOperator op,
+                                DestructiveInstTypeEnum flags> {
+  let DestructiveInstType = flags in {
+  def _B : sve_int_bin_pred_arit_log<0b00, 0b11, opc, asm, ZPR8>,
+           SVEPseudo2Instr<Ps # _B, 1>;
+  def _H : sve_int_bin_pred_arit_log<0b01, 0b11, opc, asm, ZPR16>,
+           SVEPseudo2Instr<Ps # _H, 1>;
+  def _S : sve_int_bin_pred_arit_log<0b10, 0b11, opc, asm, ZPR32>,
+           SVEPseudo2Instr<Ps # _S, 1>;
+  def _D : sve_int_bin_pred_arit_log<0b11, 0b11, opc, asm, ZPR64>,
+           SVEPseudo2Instr<Ps # _D, 1>;
+  }
 
   def : SVE_3_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
   def : SVE_3_Op_Pat<nxv8i16, op, nxv8i1,  nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
@@ -2384,11 +2392,19 @@
   def : SVE_3_Op_Pat<nxv2i64, op, nxv2i1, nxv2i64, nxv2i64, !cast<Instruction>(NAME # _D)>;
 }
 
-multiclass sve_int_bin_pred_arit_2<bits<3> opc, string asm, SDPatternOperator op> {
-  def _B : sve_int_bin_pred_arit_log<0b00, 0b10, opc, asm, ZPR8>;
-  def _H : sve_int_bin_pred_arit_log<0b01, 0b10, opc, asm, ZPR16>;
-  def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>;
-  def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>;
+multiclass sve_int_bin_pred_arit_2<bits<3> opc, string asm, string Ps,
+                                   SDPatternOperator op,
+                                   DestructiveInstTypeEnum flags> {
+  let DestructiveInstType = flags in {
+  def _B : sve_int_bin_pred_arit_log<0b00, 0b10, opc, asm, ZPR8>,
+           SVEPseudo2Instr<Ps # _B, 1>;
+  def _H : sve_int_bin_pred_arit_log<0b01, 0b10, opc, asm, ZPR16>,
+           SVEPseudo2Instr<Ps # _H, 1>;
+  def _S : sve_int_bin_pred_arit_log<0b10, 0b10, opc, asm, ZPR32>,
+           SVEPseudo2Instr<Ps # _S, 1>;
+  def _D : sve_int_bin_pred_arit_log<0b11, 0b10, opc, asm, ZPR64>,
+           SVEPseudo2Instr<Ps # _D, 1>;
+  }
 
   def : SVE_3_Op_Pat<nxv16i8, op, nxv16i1, nxv16i8, nxv16i8, !cast<Instruction>(NAME # _B)>;
   def : SVE_3_Op_Pat<nxv8i16, op, nxv8i1,  nxv8i16, nxv8i16, !cast<Instruction>(NAME # _H)>;
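
Note on the lowering these new ISD cases rely on: each of them funnels into LowerToPredicatedOp, whose closing lines appear as context in the last AArch64ISelLowering.cpp hunk above. As a rough illustration of what that does for the fixed length cases, the sketch below (a hypothetical standalone helper, simplified from the real code path; flag, chain, and already-scalable handling are omitted) widens the fixed length operands into a scalable container type, builds a predicate that enables exactly the original lanes, emits the predicated AArch64ISD node, and extracts the fixed length result back out. The static helpers it calls (getContainerForFixedLengthVector, getPredicateForFixedLengthVector, convertToScalableVector, convertFromScalableVector) already exist in AArch64ISelLowering.cpp.

// Illustrative sketch only -- not the verbatim upstream implementation.
static SDValue lowerFixedLengthOpToPredicatedSVE(SDValue Op, SelectionDAG &DAG,
                                                 unsigned PredOpc) {
  EVT VT = Op.getValueType();
  SDLoc DL(Op);

  // Pick the scalable "container" type that can hold the fixed length vector
  // and a predicate that enables only the original lanes.
  EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
  SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT);

  // Predicated AArch64ISD nodes take the governing predicate first, followed
  // by the original operands rewritten as scalable vectors.
  SmallVector<SDValue, 4> Operands = {Pg};
  for (const SDValue &V : Op->op_values())
    Operands.push_back(convertToScalableVector(DAG, ContainerVT, V));

  SDValue ScalableRes = DAG.getNode(PredOpc, DL, ContainerVT, Operands);
  return convertFromScalableVector(DAG, VT, ScalableRes);
}

With this shape of lowering, a wide fixed length operation such as an integer or floating point divide can be selected as a single predicated SVE instruction on the container type instead of being scalarised.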